// Copyright (c) 2022 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
20 21 package tchannel 22 23 import ( 24 "bytes" 25 "errors" 26 "io/ioutil" 27 "testing" 28 29 "github.com/stretchr/testify/assert" 30 "github.com/stretchr/testify/require" 31 "github.com/uber/tchannel-go" 32 "go.uber.org/yarpc/api/transport" 33 "go.uber.org/yarpc/yarpcerrors" 34 ) 35 36 func TestEncodeAndDecodeHeaders(t *testing.T) { 37 tests := []struct { 38 bytes []byte 39 headers map[string]string 40 }{ 41 {[]byte{0x00, 0x00}, nil}, 42 { 43 []byte{ 44 0x00, 0x01, // 1 header 45 46 0x00, 0x05, // length = 5 47 'h', 'e', 'l', 'l', 'o', 48 49 0x00, 0x05, // lengtth = 5 50 'w', 'o', 'r', 'l', 'd', 51 }, 52 map[string]string{"hello": "world"}, 53 }, 54 } 55 56 for _, tt := range tests { 57 headers := transport.HeadersFromMap(tt.headers) 58 assert.Equal(t, tt.bytes, encodeHeaders(tt.headers)) 59 60 result, err := decodeHeaders(bytes.NewReader(tt.bytes)) 61 if assert.NoError(t, err) { 62 assert.Equal(t, headers, result) 63 } 64 } 65 } 66 67 func TestAddCallerProcedureHeader(t *testing.T) { 68 for _, tt := range []struct { 69 desc string 70 treq transport.Request 71 headers map[string]string 72 expectedHeaders map[string]string 73 }{ 74 { 75 desc: "valid_callerProcedure_and_valid_header", 76 treq: transport.Request{CallerProcedure: "ABC"}, 77 headers: map[string]string{"header": "value"}, 78 expectedHeaders: map[string]string{ 79 CallerProcedureHeader: "ABC", 80 "header": "value", 81 }, 82 }, 83 { 84 desc: "valid_callerProcedure_and_empty_header", 85 treq: transport.Request{CallerProcedure: "ABC"}, 86 headers: nil, 87 expectedHeaders: map[string]string{CallerProcedureHeader: "ABC"}, 88 }, 89 { 90 desc: "empty_callerProcedure_and_empty_header", 91 treq: transport.Request{}, 92 headers: nil, 93 expectedHeaders: nil, 94 }, 95 { 96 desc: "empty_callerProcedure_and_valid_header", 97 treq: transport.Request{}, 98 headers: map[string]string{"header": "value"}, 99 expectedHeaders: map[string]string{"header": "value"}, 100 }, 101 } { 102 t.Run(tt.desc, func(t *testing.T) { 103 
headers := requestCallerProcedureToHeader(&tt.treq, tt.headers) 104 assert.Equal(t, tt.expectedHeaders, headers) 105 }) 106 } 107 } 108 109 func TestMoveCallerProcedureToRequest(t *testing.T) { 110 for _, tt := range []struct { 111 desc string 112 treq transport.Request 113 headers map[string]string 114 expectedTreq transport.Request 115 expectedHeaders map[string]string 116 }{ 117 { 118 desc: "no_callerProcedureReq_in_headers", 119 treq: transport.Request{}, 120 headers: map[string]string{"header": "value"}, 121 expectedTreq: transport.Request{}, 122 expectedHeaders: map[string]string{"header": "value"}, 123 }, 124 { 125 desc: "callerProcedureReq_set_in_headers", 126 treq: transport.Request{}, 127 headers: map[string]string{ 128 "header": "value", 129 CallerProcedureHeader: "ABC", 130 }, 131 expectedTreq: transport.Request{CallerProcedure: "ABC"}, 132 expectedHeaders: map[string]string{"header": "value"}, 133 }, 134 } { 135 t.Run(tt.desc, func(t *testing.T) { 136 headers := transport.HeadersFromMap(tt.headers) 137 treq := headerCallerProcedureToRequest(&tt.treq, &headers) 138 assert.Equal(t, tt.expectedTreq, *treq) 139 assert.Equal(t, transport.HeadersFromMap(tt.expectedHeaders), headers) 140 }) 141 } 142 } 143 func TestDecodeHeaderErrors(t *testing.T) { 144 tests := [][]byte{ 145 {0x00, 0x01}, 146 { 147 0x00, 0x01, 148 0x00, 0x02, 'a', 149 0x00, 0x01, 'b', 150 }, 151 } 152 153 for _, tt := range tests { 154 _, err := decodeHeaders(bytes.NewReader(tt)) 155 assert.Error(t, err) 156 } 157 } 158 159 func TestReadAndWriteHeaders(t *testing.T) { 160 tests := []struct { 161 format tchannel.Format 162 163 // the headers are serialized in an undefined order so the encoding 164 // must be one of the following 165 bytes []byte 166 orBytes []byte 167 168 headers map[string]string 169 }{ 170 { 171 tchannel.Raw, 172 []byte{ 173 0x00, 0x02, 174 0x00, 0x01, 'a', 0x00, 0x01, '1', 175 0x00, 0x01, 'b', 0x00, 0x01, '2', 176 }, 177 []byte{ 178 0x00, 0x02, 179 0x00, 0x01, 'b', 0x00, 
0x01, '2', 180 0x00, 0x01, 'a', 0x00, 0x01, '1', 181 }, 182 map[string]string{"a": "1", "b": "2"}, 183 }, 184 { 185 tchannel.JSON, 186 []byte(`{"a":"1","b":"2"}` + "\n"), 187 []byte(`{"b":"2","a":"1"}` + "\n"), 188 map[string]string{"a": "1", "b": "2"}, 189 }, 190 { 191 tchannel.Thrift, 192 []byte{ 193 0x00, 0x02, 194 0x00, 0x01, 'a', 0x00, 0x01, '1', 195 0x00, 0x01, 'b', 0x00, 0x01, '2', 196 }, 197 []byte{ 198 0x00, 0x02, 199 0x00, 0x01, 'b', 0x00, 0x01, '2', 200 0x00, 0x01, 'a', 0x00, 0x01, '1', 201 }, 202 map[string]string{"a": "1", "b": "2"}, 203 }, 204 } 205 206 for _, tt := range tests { 207 headers := transport.HeadersFromMap(tt.headers) 208 209 buffer := newBufferArgWriter() 210 err := writeHeaders(tt.format, tt.headers, nil, func() (tchannel.ArgWriter, error) { 211 return buffer, nil 212 }) 213 require.NoError(t, err) 214 215 // Result must match either tt.bytes or tt.orBytes. 216 if !bytes.Equal(tt.bytes, buffer.Bytes()) { 217 assert.Equal(t, tt.orBytes, buffer.Bytes(), "failed for %v", tt.format) 218 } 219 220 result, err := readHeaders(tt.format, func() (tchannel.ArgReader, error) { 221 reader := ioutil.NopCloser(bytes.NewReader(buffer.Bytes())) 222 return tchannel.ArgReader(reader), nil 223 }) 224 require.NoError(t, err) 225 assert.Equal(t, headers, result, "failed for %v", tt.format) 226 } 227 } 228 229 func TestReadHeadersFailure(t *testing.T) { 230 _, err := readHeaders(tchannel.Raw, func() (tchannel.ArgReader, error) { 231 return nil, errors.New("great sadness") 232 }) 233 require.Error(t, err) 234 } 235 236 func TestWriteHeaders(t *testing.T) { 237 tests := []struct { 238 msg string 239 // the headers are serialized in an undefined order so the encoding 240 // must be one of bytes or orBytes 241 bytes []byte 242 orBytes []byte 243 headers map[string]string 244 tracingBaggage map[string]string 245 }{ 246 { 247 "lowercase header", 248 []byte{ 249 0x00, 0x02, 250 0x00, 0x01, 'a', 0x00, 0x01, '1', 251 0x00, 0x01, 'b', 0x00, 0x01, '2', 252 }, 253 
[]byte{ 254 0x00, 0x02, 255 0x00, 0x01, 'b', 0x00, 0x01, '2', 256 0x00, 0x01, 'a', 0x00, 0x01, '1', 257 }, 258 map[string]string{"a": "1", "b": "2"}, 259 nil, /* tracingBaggage */ 260 }, 261 { 262 "mixed case header", 263 []byte{ 264 0x00, 0x02, 265 0x00, 0x01, 'A', 0x00, 0x01, '1', 266 0x00, 0x01, 'b', 0x00, 0x01, '2', 267 }, 268 []byte{ 269 0x00, 0x02, 270 0x00, 0x01, 'b', 0x00, 0x01, '2', 271 0x00, 0x01, 'A', 0x00, 0x01, '1', 272 }, 273 map[string]string{"A": "1", "b": "2"}, 274 nil, /* tracingBaggage */ 275 }, 276 { 277 "keys only differ by case", 278 []byte{ 279 0x00, 0x02, 280 0x00, 0x01, 'A', 0x00, 0x01, '1', 281 0x00, 0x01, 'a', 0x00, 0x01, '2', 282 }, 283 []byte{ 284 0x00, 0x02, 285 0x00, 0x01, 'a', 0x00, 0x01, '2', 286 0x00, 0x01, 'A', 0x00, 0x01, '1', 287 }, 288 map[string]string{"A": "1", "a": "2"}, 289 nil, /* tracingBaggage */ 290 }, 291 { 292 "tracing bagger header", 293 []byte{ 294 0x00, 0x02, 295 0x00, 0x01, 'a', 0x00, 0x01, '1', 296 0x00, 0x01, 'b', 0x00, 0x01, '2', 297 }, 298 []byte{ 299 0x00, 0x02, 300 0x00, 0x01, 'b', 0x00, 0x01, '2', 301 0x00, 0x01, 'a', 0x00, 0x01, '1', 302 }, 303 map[string]string{"b": "2"}, 304 map[string]string{"a": "1"}, 305 }, 306 } 307 308 for _, tt := range tests { 309 t.Run(tt.msg, func(t *testing.T) { 310 buffer := newBufferArgWriter() 311 err := writeHeaders(tchannel.Raw, tt.headers, tt.tracingBaggage, func() (tchannel.ArgWriter, error) { 312 return buffer, nil 313 }) 314 require.NoError(t, err) 315 // Result must match either tt.bytes or tt.orBytes. 
316 if !bytes.Equal(tt.bytes, buffer.Bytes()) { 317 assert.Equal(t, tt.orBytes, buffer.Bytes()) 318 } 319 }) 320 } 321 } 322 323 func TestValidateServiceHeaders(t *testing.T) { 324 tests := []struct { 325 name string 326 requestService string 327 responseService string 328 err bool 329 }{ 330 { 331 name: "match", 332 requestService: "service", 333 responseService: "service", 334 }, 335 { 336 name: "match empty", 337 }, 338 { 339 name: "match - no response", 340 requestService: "service", 341 }, 342 { 343 name: "no match", 344 requestService: "foo", 345 responseService: "bar", 346 err: true, 347 }, 348 } 349 350 for _, tt := range tests { 351 t.Run(tt.name, func(t *testing.T) { 352 if !tt.err { 353 assert.NoError(t, validateServiceName(tt.requestService, tt.responseService)) 354 355 } else { 356 err := validateServiceName(tt.requestService, tt.responseService) 357 require.Error(t, err) 358 assert.True(t, yarpcerrors.IsInternal(err), "expected yarpc.InternalError") 359 } 360 }) 361 } 362 }