go.uber.org/yarpc@v1.72.1/transport/grpc/headers_test.go (about)

     1  // Copyright (c) 2022 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package grpc
    22  
    23  import (
    24  	"testing"
    25  
    26  	"github.com/stretchr/testify/assert"
    27  	"github.com/stretchr/testify/require"
    28  	"go.uber.org/yarpc/api/transport"
    29  	"go.uber.org/yarpc/yarpcerrors"
    30  	"google.golang.org/grpc/metadata"
    31  )
    32  
    33  func TestMetadataToTransportRequest(t *testing.T) {
    34  	t.Parallel()
    35  	tests := []struct {
    36  		Name             string
    37  		MD               metadata.MD
    38  		TransportRequest *transport.Request
    39  		Error            error
    40  	}{
    41  		{
    42  			Name: "Basic",
    43  			MD: metadata.Pairs(
    44  				CallerHeader, "example-caller",
    45  				ServiceHeader, "example-service",
    46  				ShardKeyHeader, "example-shard-key",
    47  				RoutingKeyHeader, "example-routing-key",
    48  				RoutingDelegateHeader, "example-routing-delegate",
    49  				EncodingHeader, "example-encoding",
    50  				CallerProcedureHeader, "example-caller-procedure",
    51  				"foo", "bar",
    52  				"baz", "bat",
    53  			),
    54  			TransportRequest: &transport.Request{
    55  				Caller:          "example-caller",
    56  				Service:         "example-service",
    57  				ShardKey:        "example-shard-key",
    58  				RoutingKey:      "example-routing-key",
    59  				RoutingDelegate: "example-routing-delegate",
    60  				Encoding:        "example-encoding",
    61  				CallerProcedure: "example-caller-procedure",
    62  				Headers: transport.HeadersFromMap(map[string]string{
    63  					"foo": "bar",
    64  					"baz": "bat",
    65  				}),
    66  			},
    67  		},
    68  		{
    69  			Name: "Content-type",
    70  			MD: metadata.Pairs(
    71  				CallerHeader, "example-caller",
    72  				ServiceHeader, "example-service",
    73  				ShardKeyHeader, "example-shard-key",
    74  				RoutingKeyHeader, "example-routing-key",
    75  				RoutingDelegateHeader, "example-routing-delegate",
    76  				contentTypeHeader, "application/grpc+example-encoding",
    77  				"foo", "bar",
    78  				"baz", "bat",
    79  			),
    80  			TransportRequest: &transport.Request{
    81  				Caller:          "example-caller",
    82  				Service:         "example-service",
    83  				ShardKey:        "example-shard-key",
    84  				RoutingKey:      "example-routing-key",
    85  				RoutingDelegate: "example-routing-delegate",
    86  				Encoding:        "example-encoding",
    87  				Headers: transport.HeadersFromMap(map[string]string{
    88  					"foo": "bar",
    89  					"baz": "bat",
    90  				}),
    91  			},
    92  		},
    93  		{
    94  			Name: "Content-type overridden",
    95  			MD: metadata.Pairs(
    96  				CallerHeader, "example-caller",
    97  				ServiceHeader, "example-service",
    98  				ShardKeyHeader, "example-shard-key",
    99  				RoutingKeyHeader, "example-routing-key",
   100  				RoutingDelegateHeader, "example-routing-delegate",
   101  				EncodingHeader, "example-encoding-override",
   102  				contentTypeHeader, "application/grpc+example-encoding",
   103  				"foo", "bar",
   104  				"baz", "bat",
   105  			),
   106  			TransportRequest: &transport.Request{
   107  				Caller:          "example-caller",
   108  				Service:         "example-service",
   109  				ShardKey:        "example-shard-key",
   110  				RoutingKey:      "example-routing-key",
   111  				RoutingDelegate: "example-routing-delegate",
   112  				Encoding:        "example-encoding-override",
   113  				Headers: transport.HeadersFromMap(map[string]string{
   114  					"foo": "bar",
   115  					"baz": "bat",
   116  				}),
   117  			},
   118  		},
   119  	}
   120  	for _, tt := range tests {
   121  		t.Run(tt.Name, func(t *testing.T) {
   122  			transportRequest, err := metadataToTransportRequest(tt.MD)
   123  			require.Equal(t, tt.Error, err)
   124  			require.Equal(t, tt.TransportRequest, transportRequest)
   125  		})
   126  	}
   127  }
   128  
   129  func TestTransportRequestToMetadata(t *testing.T) {
   130  	t.Parallel()
   131  	for _, tt := range []struct {
   132  		Name             string
   133  		MD               metadata.MD
   134  		TransportRequest *transport.Request
   135  		Error            error
   136  	}{
   137  		{
   138  			Name: "Basic",
   139  			MD: metadata.Pairs(
   140  				CallerHeader, "example-caller",
   141  				ServiceHeader, "example-service",
   142  				ShardKeyHeader, "example-shard-key",
   143  				RoutingKeyHeader, "example-routing-key",
   144  				RoutingDelegateHeader, "example-routing-delegate",
   145  				CallerProcedureHeader, "example-caller-procedure",
   146  				EncodingHeader, "example-encoding",
   147  				"foo", "bar",
   148  				"baz", "bat",
   149  			),
   150  			TransportRequest: &transport.Request{
   151  				Caller:          "example-caller",
   152  				Service:         "example-service",
   153  				ShardKey:        "example-shard-key",
   154  				RoutingKey:      "example-routing-key",
   155  				RoutingDelegate: "example-routing-delegate",
   156  				CallerProcedure: "example-caller-procedure",
   157  				Encoding:        "example-encoding",
   158  				Headers: transport.HeadersFromMap(map[string]string{
   159  					"foo": "bar",
   160  					"baz": "bat",
   161  				}),
   162  			},
   163  		},
   164  		{
   165  			Name: "Reserved header key in application headers",
   166  			MD:   metadata.Pairs(),
   167  			TransportRequest: &transport.Request{
   168  				Headers: transport.HeadersFromMap(map[string]string{
   169  					CallerHeader: "example-caller",
   170  				}),
   171  			},
   172  			Error: yarpcerrors.InvalidArgumentErrorf("cannot use reserved header in application headers: %s", CallerHeader),
   173  		},
   174  	} {
   175  		t.Run(tt.Name, func(t *testing.T) {
   176  			md, err := transportRequestToMetadata(tt.TransportRequest)
   177  			require.Equal(t, tt.Error, err)
   178  			require.Equal(t, tt.MD, md)
   179  		})
   180  	}
   181  }
   182  
   183  func TestGetContentSubtype(t *testing.T) {
   184  	tests := []struct {
   185  		contentType    string
   186  		contentSubtype string
   187  	}{
   188  		{"application/grpc", ""},
   189  		{"application/grpc+proto", "proto"},
   190  		{"application/grpc;proto", "proto"},
   191  		{"application/grpc-proto", ""},
   192  	}
   193  	for _, tt := range tests {
   194  		assert.Equal(t, tt.contentSubtype, getContentSubtype(tt.contentType))
   195  	}
   196  }
   197  
   198  func TestIsReserved(t *testing.T) {
   199  	assert.True(t, isReserved(CallerHeader))
   200  	assert.True(t, isReserved(ServiceHeader))
   201  	assert.True(t, isReserved(ShardKeyHeader))
   202  	assert.True(t, isReserved(RoutingKeyHeader))
   203  	assert.True(t, isReserved(RoutingDelegateHeader))
   204  	assert.True(t, isReserved(EncodingHeader))
   205  	assert.True(t, isReserved("rpc-foo"))
   206  }
   207  
   208  func TestMDReadWriterDuplicateKey(t *testing.T) {
   209  	const key = "uber-trace-id"
   210  	md := map[string][]string{
   211  		key: {"to-override"},
   212  	}
   213  	mdRW := mdReadWriter(md)
   214  	mdRW.Set(key, "overwritten")
   215  	assert.Equal(t, []string{"overwritten"}, md[key], "expected overwritten values")
   216  }
   217  
   218  func TestGetApplicationHeaders(t *testing.T) {
   219  	tests := []struct {
   220  		msg         string
   221  		meta        metadata.MD
   222  		wantHeaders map[string]string
   223  		wantErr     string
   224  	}{
   225  		{
   226  			msg:         "nil",
   227  			meta:        nil,
   228  			wantHeaders: nil,
   229  		},
   230  		{
   231  			msg:         "empty",
   232  			meta:        metadata.MD{},
   233  			wantHeaders: nil,
   234  		},
   235  		{
   236  			msg: "success",
   237  			meta: metadata.MD{
   238  				"rpc-service":         []string{"foo"}, // reserved header
   239  				"test-header-empty":   []string{},      // no value
   240  				"test-header-valid-1": []string{"test-value-1"},
   241  				"test-Header-Valid-2": []string{"test-value-2"},
   242  			},
   243  			wantHeaders: map[string]string{
   244  				"test-header-valid-1": "test-value-1",
   245  				"test-header-valid-2": "test-value-2",
   246  			},
   247  		},
   248  		{
   249  			msg: "error: multiple values for one header",
   250  			meta: metadata.MD{
   251  				"test-header-valid": []string{"test-value"},
   252  				"test-header-dup":   []string{"test-value-1", "test-value-2"},
   253  			},
   254  			wantErr: "header has more than one value: test-header-dup:[test-value-1 test-value-2]",
   255  		},
   256  	}
   257  
   258  	for _, tt := range tests {
   259  		t.Run(tt.msg, func(t *testing.T) {
   260  			got, err := getApplicationHeaders(tt.meta)
   261  			if tt.wantErr != "" {
   262  				require.Error(t, err)
   263  				assert.Contains(t, err.Error(), tt.wantErr, "unexpecte error message")
   264  				return
   265  			}
   266  			require.NoError(t, err, "failed to extract application headers")
   267  			assert.Equal(t, tt.wantHeaders, got.Items(), "unexpected headers")
   268  		})
   269  	}
   270  }