go.uber.org/yarpc@v1.72.1/transport/tchannel/header_test.go (about)

     1  // Copyright (c) 2022 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package tchannel
    22  
    23  import (
    24  	"bytes"
    25  	"errors"
    26  	"io/ioutil"
    27  	"testing"
    28  
    29  	"github.com/stretchr/testify/assert"
    30  	"github.com/stretchr/testify/require"
    31  	"github.com/uber/tchannel-go"
    32  	"go.uber.org/yarpc/api/transport"
    33  	"go.uber.org/yarpc/yarpcerrors"
    34  )
    35  
    36  func TestEncodeAndDecodeHeaders(t *testing.T) {
    37  	tests := []struct {
    38  		bytes   []byte
    39  		headers map[string]string
    40  	}{
    41  		{[]byte{0x00, 0x00}, nil},
    42  		{
    43  			[]byte{
    44  				0x00, 0x01, // 1 header
    45  
    46  				0x00, 0x05, // length = 5
    47  				'h', 'e', 'l', 'l', 'o',
    48  
    49  				0x00, 0x05, // lengtth = 5
    50  				'w', 'o', 'r', 'l', 'd',
    51  			},
    52  			map[string]string{"hello": "world"},
    53  		},
    54  	}
    55  
    56  	for _, tt := range tests {
    57  		headers := transport.HeadersFromMap(tt.headers)
    58  		assert.Equal(t, tt.bytes, encodeHeaders(tt.headers))
    59  
    60  		result, err := decodeHeaders(bytes.NewReader(tt.bytes))
    61  		if assert.NoError(t, err) {
    62  			assert.Equal(t, headers, result)
    63  		}
    64  	}
    65  }
    66  
    67  func TestAddCallerProcedureHeader(t *testing.T) {
    68  	for _, tt := range []struct {
    69  		desc            string
    70  		treq            transport.Request
    71  		headers         map[string]string
    72  		expectedHeaders map[string]string
    73  	}{
    74  		{
    75  			desc:    "valid_callerProcedure_and_valid_header",
    76  			treq:    transport.Request{CallerProcedure: "ABC"},
    77  			headers: map[string]string{"header": "value"},
    78  			expectedHeaders: map[string]string{
    79  				CallerProcedureHeader: "ABC",
    80  				"header":              "value",
    81  			},
    82  		},
    83  		{
    84  			desc:            "valid_callerProcedure_and_empty_header",
    85  			treq:            transport.Request{CallerProcedure: "ABC"},
    86  			headers:         nil,
    87  			expectedHeaders: map[string]string{CallerProcedureHeader: "ABC"},
    88  		},
    89  		{
    90  			desc:            "empty_callerProcedure_and_empty_header",
    91  			treq:            transport.Request{},
    92  			headers:         nil,
    93  			expectedHeaders: nil,
    94  		},
    95  		{
    96  			desc:            "empty_callerProcedure_and_valid_header",
    97  			treq:            transport.Request{},
    98  			headers:         map[string]string{"header": "value"},
    99  			expectedHeaders: map[string]string{"header": "value"},
   100  		},
   101  	} {
   102  		t.Run(tt.desc, func(t *testing.T) {
   103  			headers := requestCallerProcedureToHeader(&tt.treq, tt.headers)
   104  			assert.Equal(t, tt.expectedHeaders, headers)
   105  		})
   106  	}
   107  }
   108  
   109  func TestMoveCallerProcedureToRequest(t *testing.T) {
   110  	for _, tt := range []struct {
   111  		desc            string
   112  		treq            transport.Request
   113  		headers         map[string]string
   114  		expectedTreq    transport.Request
   115  		expectedHeaders map[string]string
   116  	}{
   117  		{
   118  			desc:            "no_callerProcedureReq_in_headers",
   119  			treq:            transport.Request{},
   120  			headers:         map[string]string{"header": "value"},
   121  			expectedTreq:    transport.Request{},
   122  			expectedHeaders: map[string]string{"header": "value"},
   123  		},
   124  		{
   125  			desc: "callerProcedureReq_set_in_headers",
   126  			treq: transport.Request{},
   127  			headers: map[string]string{
   128  				"header":              "value",
   129  				CallerProcedureHeader: "ABC",
   130  			},
   131  			expectedTreq:    transport.Request{CallerProcedure: "ABC"},
   132  			expectedHeaders: map[string]string{"header": "value"},
   133  		},
   134  	} {
   135  		t.Run(tt.desc, func(t *testing.T) {
   136  			headers := transport.HeadersFromMap(tt.headers)
   137  			treq := headerCallerProcedureToRequest(&tt.treq, &headers)
   138  			assert.Equal(t, tt.expectedTreq, *treq)
   139  			assert.Equal(t, transport.HeadersFromMap(tt.expectedHeaders), headers)
   140  		})
   141  	}
   142  }
   143  func TestDecodeHeaderErrors(t *testing.T) {
   144  	tests := [][]byte{
   145  		{0x00, 0x01},
   146  		{
   147  			0x00, 0x01,
   148  			0x00, 0x02, 'a',
   149  			0x00, 0x01, 'b',
   150  		},
   151  	}
   152  
   153  	for _, tt := range tests {
   154  		_, err := decodeHeaders(bytes.NewReader(tt))
   155  		assert.Error(t, err)
   156  	}
   157  }
   158  
   159  func TestReadAndWriteHeaders(t *testing.T) {
   160  	tests := []struct {
   161  		format tchannel.Format
   162  
   163  		// the headers are serialized in an undefined order so the encoding
   164  		// must be one of the following
   165  		bytes   []byte
   166  		orBytes []byte
   167  
   168  		headers map[string]string
   169  	}{
   170  		{
   171  			tchannel.Raw,
   172  			[]byte{
   173  				0x00, 0x02,
   174  				0x00, 0x01, 'a', 0x00, 0x01, '1',
   175  				0x00, 0x01, 'b', 0x00, 0x01, '2',
   176  			},
   177  			[]byte{
   178  				0x00, 0x02,
   179  				0x00, 0x01, 'b', 0x00, 0x01, '2',
   180  				0x00, 0x01, 'a', 0x00, 0x01, '1',
   181  			},
   182  			map[string]string{"a": "1", "b": "2"},
   183  		},
   184  		{
   185  			tchannel.JSON,
   186  			[]byte(`{"a":"1","b":"2"}` + "\n"),
   187  			[]byte(`{"b":"2","a":"1"}` + "\n"),
   188  			map[string]string{"a": "1", "b": "2"},
   189  		},
   190  		{
   191  			tchannel.Thrift,
   192  			[]byte{
   193  				0x00, 0x02,
   194  				0x00, 0x01, 'a', 0x00, 0x01, '1',
   195  				0x00, 0x01, 'b', 0x00, 0x01, '2',
   196  			},
   197  			[]byte{
   198  				0x00, 0x02,
   199  				0x00, 0x01, 'b', 0x00, 0x01, '2',
   200  				0x00, 0x01, 'a', 0x00, 0x01, '1',
   201  			},
   202  			map[string]string{"a": "1", "b": "2"},
   203  		},
   204  	}
   205  
   206  	for _, tt := range tests {
   207  		headers := transport.HeadersFromMap(tt.headers)
   208  
   209  		buffer := newBufferArgWriter()
   210  		err := writeHeaders(tt.format, tt.headers, nil, func() (tchannel.ArgWriter, error) {
   211  			return buffer, nil
   212  		})
   213  		require.NoError(t, err)
   214  
   215  		// Result must match either tt.bytes or tt.orBytes.
   216  		if !bytes.Equal(tt.bytes, buffer.Bytes()) {
   217  			assert.Equal(t, tt.orBytes, buffer.Bytes(), "failed for %v", tt.format)
   218  		}
   219  
   220  		result, err := readHeaders(tt.format, func() (tchannel.ArgReader, error) {
   221  			reader := ioutil.NopCloser(bytes.NewReader(buffer.Bytes()))
   222  			return tchannel.ArgReader(reader), nil
   223  		})
   224  		require.NoError(t, err)
   225  		assert.Equal(t, headers, result, "failed for %v", tt.format)
   226  	}
   227  }
   228  
   229  func TestReadHeadersFailure(t *testing.T) {
   230  	_, err := readHeaders(tchannel.Raw, func() (tchannel.ArgReader, error) {
   231  		return nil, errors.New("great sadness")
   232  	})
   233  	require.Error(t, err)
   234  }
   235  
   236  func TestWriteHeaders(t *testing.T) {
   237  	tests := []struct {
   238  		msg string
   239  		// the headers are serialized in an undefined order so the encoding
   240  		// must be one of bytes or orBytes
   241  		bytes          []byte
   242  		orBytes        []byte
   243  		headers        map[string]string
   244  		tracingBaggage map[string]string
   245  	}{
   246  		{
   247  			"lowercase header",
   248  			[]byte{
   249  				0x00, 0x02,
   250  				0x00, 0x01, 'a', 0x00, 0x01, '1',
   251  				0x00, 0x01, 'b', 0x00, 0x01, '2',
   252  			},
   253  			[]byte{
   254  				0x00, 0x02,
   255  				0x00, 0x01, 'b', 0x00, 0x01, '2',
   256  				0x00, 0x01, 'a', 0x00, 0x01, '1',
   257  			},
   258  			map[string]string{"a": "1", "b": "2"},
   259  			nil, /* tracingBaggage */
   260  		},
   261  		{
   262  			"mixed case header",
   263  			[]byte{
   264  				0x00, 0x02,
   265  				0x00, 0x01, 'A', 0x00, 0x01, '1',
   266  				0x00, 0x01, 'b', 0x00, 0x01, '2',
   267  			},
   268  			[]byte{
   269  				0x00, 0x02,
   270  				0x00, 0x01, 'b', 0x00, 0x01, '2',
   271  				0x00, 0x01, 'A', 0x00, 0x01, '1',
   272  			},
   273  			map[string]string{"A": "1", "b": "2"},
   274  			nil, /* tracingBaggage */
   275  		},
   276  		{
   277  			"keys only differ by case",
   278  			[]byte{
   279  				0x00, 0x02,
   280  				0x00, 0x01, 'A', 0x00, 0x01, '1',
   281  				0x00, 0x01, 'a', 0x00, 0x01, '2',
   282  			},
   283  			[]byte{
   284  				0x00, 0x02,
   285  				0x00, 0x01, 'a', 0x00, 0x01, '2',
   286  				0x00, 0x01, 'A', 0x00, 0x01, '1',
   287  			},
   288  			map[string]string{"A": "1", "a": "2"},
   289  			nil, /* tracingBaggage */
   290  		},
   291  		{
   292  			"tracing bagger header",
   293  			[]byte{
   294  				0x00, 0x02,
   295  				0x00, 0x01, 'a', 0x00, 0x01, '1',
   296  				0x00, 0x01, 'b', 0x00, 0x01, '2',
   297  			},
   298  			[]byte{
   299  				0x00, 0x02,
   300  				0x00, 0x01, 'b', 0x00, 0x01, '2',
   301  				0x00, 0x01, 'a', 0x00, 0x01, '1',
   302  			},
   303  			map[string]string{"b": "2"},
   304  			map[string]string{"a": "1"},
   305  		},
   306  	}
   307  
   308  	for _, tt := range tests {
   309  		t.Run(tt.msg, func(t *testing.T) {
   310  			buffer := newBufferArgWriter()
   311  			err := writeHeaders(tchannel.Raw, tt.headers, tt.tracingBaggage, func() (tchannel.ArgWriter, error) {
   312  				return buffer, nil
   313  			})
   314  			require.NoError(t, err)
   315  			// Result must match either tt.bytes or tt.orBytes.
   316  			if !bytes.Equal(tt.bytes, buffer.Bytes()) {
   317  				assert.Equal(t, tt.orBytes, buffer.Bytes())
   318  			}
   319  		})
   320  	}
   321  }
   322  
   323  func TestValidateServiceHeaders(t *testing.T) {
   324  	tests := []struct {
   325  		name            string
   326  		requestService  string
   327  		responseService string
   328  		err             bool
   329  	}{
   330  		{
   331  			name:            "match",
   332  			requestService:  "service",
   333  			responseService: "service",
   334  		},
   335  		{
   336  			name: "match empty",
   337  		},
   338  		{
   339  			name:           "match - no response",
   340  			requestService: "service",
   341  		},
   342  		{
   343  			name:            "no match",
   344  			requestService:  "foo",
   345  			responseService: "bar",
   346  			err:             true,
   347  		},
   348  	}
   349  
   350  	for _, tt := range tests {
   351  		t.Run(tt.name, func(t *testing.T) {
   352  			if !tt.err {
   353  				assert.NoError(t, validateServiceName(tt.requestService, tt.responseService))
   354  
   355  			} else {
   356  				err := validateServiceName(tt.requestService, tt.responseService)
   357  				require.Error(t, err)
   358  				assert.True(t, yarpcerrors.IsInternal(err), "expected yarpc.InternalError")
   359  			}
   360  		})
   361  	}
   362  }