go.uber.org/yarpc@v1.72.1/api/transport/request_test.go (about)

     1  // Copyright (c) 2022 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package transport_test
    22  
    23  import (
    24  	"context"
    25  	"strings"
    26  	"testing"
    27  	"time"
    28  
    29  	"github.com/stretchr/testify/assert"
    30  	"go.uber.org/yarpc/api/transport"
    31  	"go.uber.org/zap/zapcore"
    32  )
    33  
    34  func TestValidator(t *testing.T) {
    35  	tests := []struct {
    36  		req           *transport.Request
    37  		transportType transport.Type
    38  		ttl           time.Duration
    39  
    40  		wantMissingParams []string
    41  	}{
    42  		{
    43  			// No error
    44  			req: &transport.Request{
    45  				Caller:    "caller",
    46  				Service:   "service",
    47  				Encoding:  "raw",
    48  				Procedure: "hello",
    49  			},
    50  			ttl: time.Second,
    51  		},
    52  		{
    53  			// encoding is not required
    54  			req: &transport.Request{
    55  				Caller:    "caller",
    56  				Service:   "service",
    57  				Procedure: "hello",
    58  			},
    59  			wantMissingParams: []string{"encoding"},
    60  		},
    61  		{
    62  			req: &transport.Request{
    63  				Service:   "service",
    64  				Procedure: "hello",
    65  				Encoding:  "raw",
    66  			},
    67  			wantMissingParams: []string{"caller"},
    68  		},
    69  		{
    70  			req: &transport.Request{
    71  				Caller:    "caller",
    72  				Procedure: "hello",
    73  				Encoding:  "raw",
    74  			},
    75  			wantMissingParams: []string{"service"},
    76  		},
    77  		{
    78  			req: &transport.Request{
    79  				Caller:   "caller",
    80  				Service:  "service",
    81  				Encoding: "raw",
    82  			},
    83  			wantMissingParams: []string{"procedure"},
    84  		},
    85  		{
    86  			req: &transport.Request{
    87  				Caller:    "caller",
    88  				Service:   "service",
    89  				Procedure: "hello",
    90  				Encoding:  "raw",
    91  			},
    92  			transportType:     transport.Unary,
    93  			wantMissingParams: []string{"TTL"},
    94  		},
    95  		{
    96  			req:               &transport.Request{},
    97  			wantMissingParams: []string{"encoding", "caller", "service", "procedure"},
    98  		},
    99  	}
   100  
   101  	for _, tt := range tests {
   102  		ctx := context.Background()
   103  		err := transport.ValidateRequest(tt.req)
   104  
   105  		if err == nil && tt.transportType == transport.Unary {
   106  			var cancel func()
   107  
   108  			if tt.ttl != 0 {
   109  				ctx, cancel = context.WithTimeout(ctx, tt.ttl)
   110  				defer cancel()
   111  			}
   112  
   113  			err = transport.ValidateRequestContext(ctx)
   114  		}
   115  
   116  		if len(tt.wantMissingParams) > 0 {
   117  			if assert.Error(t, err) {
   118  				for _, wantMissingParam := range tt.wantMissingParams {
   119  					assert.Contains(t, err.Error(), wantMissingParam)
   120  				}
   121  			}
   122  		} else {
   123  			assert.NoError(t, err)
   124  		}
   125  	}
   126  }
   127  
   128  func TestRequestLogMarshaling(t *testing.T) {
   129  	r := &transport.Request{
   130  		Caller:          "caller",
   131  		Service:         "service",
   132  		Transport:       "transport",
   133  		Encoding:        "raw",
   134  		Procedure:       "procedure",
   135  		Headers:         transport.NewHeaders().With("password", "super-secret"),
   136  		ShardKey:        "shard01",
   137  		RoutingKey:      "routing-key",
   138  		RoutingDelegate: "routing-delegate",
   139  		CallerProcedure: "caller-procedure",
   140  		Body:            strings.NewReader("body"),
   141  	}
   142  	enc := zapcore.NewMapObjectEncoder()
   143  	assert.NoError(t, r.MarshalLogObject(enc), "Unexpected error marshaling request.")
   144  	assert.Equal(t, map[string]interface{}{
   145  		"caller":          "caller",
   146  		"service":         "service",
   147  		"transport":       "transport",
   148  		"encoding":        "raw",
   149  		"procedure":       "procedure",
   150  		"shardKey":        "shard01",
   151  		"routingKey":      "routing-key",
   152  		"routingDelegate": "routing-delegate",
   153  		"callerProcedure": "caller-procedure",
   154  	}, enc.Fields, "Unexpected output after marshaling request.")
   155  }
   156  
   157  func TestRequestMetaToRequestConversionAndBack(t *testing.T) {
   158  	reqMeta := &transport.RequestMeta{
   159  		Caller:          "caller",
   160  		Service:         "service",
   161  		Transport:       "transport",
   162  		Encoding:        "raw",
   163  		Procedure:       "hello",
   164  		Headers:         transport.NewHeaders().With("key", "val"),
   165  		ShardKey:        "shard",
   166  		RoutingKey:      "rk",
   167  		RoutingDelegate: "rd",
   168  		CallerProcedure: "cp",
   169  	}
   170  
   171  	req := reqMeta.ToRequest()
   172  
   173  	assert.Equal(t, reqMeta, req.ToRequestMeta())
   174  }
   175  
   176  func TestNilRequestMetaToRequestConversion(t *testing.T) {
   177  	var reqMeta *transport.RequestMeta
   178  
   179  	assert.Equal(t, &transport.Request{}, reqMeta.ToRequest())
   180  }