go.uber.org/yarpc@v1.72.1/api/transport/transporttest/reqres.go (about)

     1  // Copyright (c) 2022 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package transporttest
    22  
    23  import (
    24  	"bytes"
    25  	"fmt"
    26  	"io/ioutil"
    27  	"reflect"
    28  	"strings"
    29  	"testing"
    30  
    31  	"go.uber.org/yarpc/api/transport"
    32  )
    33  
    34  // RequestMatcher may be used in gomock argument lists to assert that two
    35  // requests match.
    36  //
    37  // Requests are considered to be matching if: all their primitive parameters
    38  // match, the headers of the received request include all the headers from the
    39  // source request, and the contents of the request bodies are the same.
    40  type RequestMatcher struct {
    41  	t    *testing.T
    42  	req  *transport.Request
    43  	body []byte
    44  }
    45  
    46  // NewRequestMatcher constructs a new RequestMatcher from the given testing.T
    47  // and request.
    48  //
    49  // The request's contents are read in their entirety and replaced with a
    50  // bytes.Reader.
    51  func NewRequestMatcher(t *testing.T, r *transport.Request) RequestMatcher {
    52  	body, err := ioutil.ReadAll(r.Body)
    53  	if err != nil {
    54  		t.Fatalf("failed to read request body: %v", err)
    55  	}
    56  
    57  	// restore a copy of the body so that the caller can still use the request
    58  	// object
    59  	r.Body = bytes.NewReader(body)
    60  	return RequestMatcher{t: t, req: r, body: body}
    61  }
    62  
    63  // TODO: Headers like User-Agent, Content-Length, etc. make their way to the
    64  // user-level Request object. For now, we're doing the super set check but we
    65  // should do something more specific once yarpc/yarpc.io#2 is resolved.
    66  
    67  // Matches checks if the given object matches the Request provided in
    68  // NewRequestMatcher.
    69  func (m RequestMatcher) Matches(got interface{}) bool {
    70  	l := m.req
    71  	r, ok := got.(*transport.Request)
    72  	if !ok {
    73  		panic(fmt.Sprintf("expected *transport.Request, got %v", got))
    74  	}
    75  
    76  	if l.Caller != r.Caller {
    77  		m.t.Logf("Caller mismatch: %s != %s", l.Caller, r.Caller)
    78  		return false
    79  	}
    80  
    81  	if l.Service != r.Service {
    82  		m.t.Logf("Service mismatch: %s != %s", l.Service, r.Service)
    83  		return false
    84  	}
    85  
    86  	// We check if 'Transport' is set before comparing, since transports may
    87  	// modify the field after users define the struct.
    88  	if l.Transport != "" && l.Transport != r.Transport {
    89  		m.t.Logf("Transport mismatch: %s != %s", l.Transport, r.Transport)
    90  		return false
    91  	}
    92  
    93  	if l.Encoding != r.Encoding {
    94  		m.t.Logf("Encoding mismatch: %s != %s", l.Service, r.Service)
    95  		return false
    96  	}
    97  
    98  	if l.Procedure != r.Procedure {
    99  		m.t.Logf("Procedure mismatch: %s != %s", l.Procedure, r.Procedure)
   100  		return false
   101  	}
   102  
   103  	if l.ShardKey != r.ShardKey {
   104  		m.t.Logf("Shard Key mismatch: %s != %s", l.ShardKey, r.ShardKey)
   105  		return false
   106  	}
   107  
   108  	if l.RoutingKey != r.RoutingKey {
   109  		m.t.Logf("Routing Key mismatch: %s != %s", l.RoutingKey, r.RoutingKey)
   110  		return false
   111  	}
   112  
   113  	if l.RoutingDelegate != r.RoutingDelegate {
   114  		m.t.Logf("Routing Delegate mismatch: %s != %s", l.RoutingDelegate, r.RoutingDelegate)
   115  		return false
   116  	}
   117  
   118  	// len check to handle nil vs empty cases gracefully.
   119  	if l.Headers.Len() != r.Headers.Len() {
   120  		if !reflect.DeepEqual(l.Headers, r.Headers) {
   121  			m.t.Logf("Headers did not match:\n\t   %v\n\t!= %v", l.Headers, r.Headers)
   122  			return false
   123  		}
   124  	}
   125  
   126  	rbody, err := ioutil.ReadAll(r.Body)
   127  	if err != nil {
   128  		m.t.Fatalf("failed to read body: %v", err)
   129  	}
   130  	r.Body = bytes.NewReader(rbody) // in case it is reused
   131  
   132  	if !bytes.Equal(m.body, rbody) {
   133  		m.t.Logf("Body mismatch: %v != %v", m.body, rbody)
   134  		return false
   135  	}
   136  
   137  	return true
   138  }
   139  
   140  func (m RequestMatcher) String() string {
   141  	return fmt.Sprintf("matches request %v with body %v", m.req, m.body)
   142  }
   143  
   144  // checkSuperSet checks if the items in l are all also present in r.
   145  func checkSuperSet(l, r transport.Headers) error {
   146  	missing := make([]string, 0, l.Len())
   147  	for k, vl := range l.Items() {
   148  		vr, ok := r.Get(k)
   149  		if !ok || vr != vl {
   150  			missing = append(missing, k)
   151  		}
   152  	}
   153  
   154  	if len(missing) > 0 {
   155  		return fmt.Errorf("missing headers: %v", strings.Join(missing, ", "))
   156  	}
   157  	return nil
   158  }
   159  
   160  // ResponseMatcher is similar to RequestMatcher but for responses.
   161  type ResponseMatcher struct {
   162  	t    *testing.T
   163  	res  *transport.Response
   164  	body []byte
   165  }
   166  
   167  // NewResponseMatcher builds a new ResponseMatcher that verifies that
   168  // responses match the given Response.
   169  func NewResponseMatcher(t *testing.T, r *transport.Response) ResponseMatcher {
   170  	body, err := ioutil.ReadAll(r.Body)
   171  	defer r.Body.Close()
   172  	if err != nil {
   173  		t.Fatalf("failed to read response body: %v", err)
   174  	}
   175  
   176  	// restore a copy of the body so that the caller can still use the
   177  	// response object
   178  	r.Body = ioutil.NopCloser(bytes.NewReader(body))
   179  	return ResponseMatcher{t: t, res: r, body: body}
   180  }
   181  
   182  // Matches checks if the given object matches the Response provided in
   183  // NewResponseMatcher.
   184  func (m ResponseMatcher) Matches(got interface{}) bool {
   185  	l := m.res
   186  	r, ok := got.(*transport.Response)
   187  	if !ok {
   188  		panic(fmt.Sprintf("expected *transport.Response, got %v", got))
   189  	}
   190  
   191  	if err := checkSuperSet(l.Headers, r.Headers); err != nil {
   192  		m.t.Logf("Headers mismatch: %v != %v\n\t%v", l.Headers, r.Headers, err)
   193  		return false
   194  	}
   195  
   196  	rbody, err := ioutil.ReadAll(r.Body)
   197  	if err != nil {
   198  		m.t.Fatalf("failed to read body: %v", err)
   199  	}
   200  	r.Body = ioutil.NopCloser(bytes.NewReader(rbody)) // in case it is reused
   201  
   202  	if !bytes.Equal(m.body, rbody) {
   203  		m.t.Logf("Body mismatch: %v != %v", m.body, rbody)
   204  		return false
   205  	}
   206  
   207  	return true
   208  }
   209  
   210  // FakeResponseWriter is a ResponseWriter that records the headers and the body
   211  // written to it.
   212  type FakeResponseWriter struct {
   213  	IsApplicationError   bool
   214  	ApplicationErrorMeta *transport.ApplicationErrorMeta
   215  	Headers              transport.Headers
   216  	Body                 bytes.Buffer
   217  }
   218  
   219  // SetApplicationError for FakeResponseWriter.
   220  func (fw *FakeResponseWriter) SetApplicationError() {
   221  	fw.IsApplicationError = true
   222  }
   223  
   224  // AddHeaders for FakeResponseWriter.
   225  func (fw *FakeResponseWriter) AddHeaders(h transport.Headers) {
   226  	for k, v := range h.OriginalItems() {
   227  		fw.Headers = fw.Headers.With(k, v)
   228  	}
   229  }
   230  
   231  // Write for FakeResponseWriter.
   232  func (fw *FakeResponseWriter) Write(s []byte) (int, error) {
   233  	return fw.Body.Write(s)
   234  }
   235  
   236  // SetApplicationErrorMeta for FakeResponseWriter
   237  func (fw *FakeResponseWriter) SetApplicationErrorMeta(applicationErrorMeta *transport.ApplicationErrorMeta) {
   238  	fw.ApplicationErrorMeta = applicationErrorMeta
   239  }