go.uber.org/yarpc@v1.72.1/yarpctest/recorder/recorder_test.go (about)

     1  // Copyright (c) 2022 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package recorder
    22  
    23  import (
    24  	"bytes"
    25  	"context"
    26  	"fmt"
    27  	"io/ioutil"
    28  	"math/rand"
    29  	"os"
    30  	"path"
    31  	"testing"
    32  	"time"
    33  
    34  	"github.com/stretchr/testify/assert"
    35  	"github.com/stretchr/testify/require"
    36  	"go.uber.org/yarpc"
    37  	"go.uber.org/yarpc/api/transport"
    38  	"go.uber.org/yarpc/encoding/raw"
    39  	"go.uber.org/yarpc/internal/yarpctest"
    40  	"go.uber.org/yarpc/transport/http"
    41  )
    42  
    43  func TestSanitizeFilename(t *testing.T) {
    44  	assert.EqualValues(t, sanitizeFilename(`hello`), `hello`)
    45  	assert.EqualValues(t, sanitizeFilename(`h/e\l?l%o*`), `h_e_l_l_o_`)
    46  	assert.EqualValues(t, sanitizeFilename(`:h|e"l<l>o.`), `_h_e_l_l_o.`)
    47  	assert.EqualValues(t, sanitizeFilename(`10€|çí¹`), `10__çí¹`)
    48  	assert.EqualValues(t, sanitizeFilename("hel\x00lo"), `hel_lo`)
    49  }
    50  
    51  type randomGenerator struct {
    52  	randsrc *rand.Rand
    53  }
    54  
    55  func newRandomGenerator(seed int64) randomGenerator {
    56  	return randomGenerator{
    57  		randsrc: rand.New(rand.NewSource(seed)),
    58  	}
    59  }
    60  
    61  // Atom returns an ASCII string.
    62  func (r *randomGenerator) Atom() string {
    63  	length := 3 + r.randsrc.Intn(13)
    64  	atom := make([]byte, length)
    65  	for i := 0; i < length; i++ {
    66  		letter := r.randsrc.Intn(2 * 26)
    67  		if letter < 26 {
    68  			atom[i] = 'A' + byte(letter)
    69  		} else {
    70  			atom[i] = 'a' + byte(letter-26)
    71  		}
    72  	}
    73  	return string(atom)
    74  }
    75  
    76  // Headers returns a new randomized header.
    77  func (r *randomGenerator) Headers() transport.Headers {
    78  	headers := transport.NewHeaders()
    79  	size := 2 + r.randsrc.Intn(6)
    80  	for i := 0; i < size; i++ {
    81  		headers = headers.With(r.Atom(), r.Atom())
    82  	}
    83  	return headers
    84  }
    85  
    86  // Request returns a new randomized request.
    87  func (r *randomGenerator) Request() transport.Request {
    88  	bodyData := []byte(r.Atom())
    89  
    90  	return transport.Request{
    91  		Caller:          r.Atom(),
    92  		Service:         r.Atom(),
    93  		Encoding:        transport.Encoding(r.Atom()),
    94  		Procedure:       r.Atom(),
    95  		Headers:         r.Headers(),
    96  		ShardKey:        r.Atom(),
    97  		RoutingKey:      r.Atom(),
    98  		RoutingDelegate: r.Atom(),
    99  		Body:            ioutil.NopCloser(bytes.NewReader(bodyData)),
   100  	}
   101  }
   102  
   103  func TestHash(t *testing.T) {
   104  	rgen := newRandomGenerator(42)
   105  	request := rgen.Request()
   106  
   107  	recorder := NewRecorder(t)
   108  	requestRecord := recorder.requestToRequestRecord(&request)
   109  	referenceHash := recorder.hashRequestRecord(&requestRecord)
   110  
   111  	require.Equal(t, "7195d5a712201d2a", referenceHash)
   112  
   113  	// Caller
   114  	r := request
   115  	r.Caller = rgen.Atom()
   116  	requestRecord = recorder.requestToRequestRecord(&r)
   117  	assert.NotEqual(t, recorder.hashRequestRecord(&requestRecord), referenceHash)
   118  
   119  	// Service
   120  	r = request
   121  	r.Service = rgen.Atom()
   122  	requestRecord = recorder.requestToRequestRecord(&r)
   123  	assert.NotEqual(t, recorder.hashRequestRecord(&requestRecord), referenceHash)
   124  
   125  	// Encoding
   126  	r = request
   127  	r.Encoding = transport.Encoding(rgen.Atom())
   128  	requestRecord = recorder.requestToRequestRecord(&r)
   129  	assert.NotEqual(t, recorder.hashRequestRecord(&requestRecord), referenceHash)
   130  
   131  	// Procedure
   132  	r = request
   133  	r.Procedure = rgen.Atom()
   134  	requestRecord = recorder.requestToRequestRecord(&r)
   135  	assert.NotEqual(t, recorder.hashRequestRecord(&requestRecord), referenceHash)
   136  
   137  	// Headers
   138  	r = request
   139  	r.Headers = rgen.Headers()
   140  	requestRecord = recorder.requestToRequestRecord(&r)
   141  	assert.NotEqual(t, recorder.hashRequestRecord(&requestRecord), referenceHash)
   142  
   143  	// ShardKey
   144  	r = request
   145  	r.ShardKey = rgen.Atom()
   146  	requestRecord = recorder.requestToRequestRecord(&r)
   147  	assert.NotEqual(t, recorder.hashRequestRecord(&requestRecord), referenceHash)
   148  
   149  	// RoutingKey
   150  	r = request
   151  	r.RoutingKey = rgen.Atom()
   152  	requestRecord = recorder.requestToRequestRecord(&r)
   153  	assert.NotEqual(t, recorder.hashRequestRecord(&requestRecord), referenceHash)
   154  
   155  	// RoutingDelegate
   156  	r = request
   157  	r.RoutingDelegate = rgen.Atom()
   158  	requestRecord = recorder.requestToRequestRecord(&r)
   159  	assert.NotEqual(t, recorder.hashRequestRecord(&requestRecord), referenceHash)
   160  
   161  	// Body
   162  	r = request
   163  	request.Body = ioutil.NopCloser(bytes.NewReader([]byte(rgen.Atom())))
   164  	requestRecord = recorder.requestToRequestRecord(&r)
   165  	assert.NotEqual(t, recorder.hashRequestRecord(&requestRecord), referenceHash)
   166  }
   167  
   168  var testingTMockFatal = struct{}{}
   169  
   170  type testingTMock struct {
   171  	*testing.T
   172  
   173  	fatalCount int
   174  }
   175  
   176  func (t *testingTMock) Fatal(args ...interface{}) {
   177  	t.Logf("counting fatal: %s", args)
   178  	t.fatalCount++
   179  	panic(testingTMockFatal)
   180  }
   181  
   182  func withDisconnectedClient(t *testing.T, recorder *Recorder, f func(raw.Client)) {
   183  	httpTransport := http.NewTransport()
   184  
   185  	clientDisp := yarpc.NewDispatcher(yarpc.Config{
   186  		Name: "client",
   187  		Outbounds: yarpc.Outbounds{
   188  			"server": {
   189  				Unary: httpTransport.NewSingleOutbound("http://127.0.0.1:65535"),
   190  			},
   191  		},
   192  		OutboundMiddleware: yarpc.OutboundMiddleware{
   193  			Unary: recorder,
   194  		},
   195  	})
   196  	require.NoError(t, clientDisp.Start())
   197  	defer clientDisp.Stop()
   198  
   199  	client := raw.New(clientDisp.ClientConfig("server"))
   200  	f(client)
   201  }
   202  
   203  func withConnectedClient(t *testing.T, recorder *Recorder, f func(raw.Client)) {
   204  	httpTransport := http.NewTransport()
   205  	serverHTTP := httpTransport.NewInbound("127.0.0.1:0")
   206  	serverDisp := yarpc.NewDispatcher(yarpc.Config{
   207  		Name:     "server",
   208  		Inbounds: yarpc.Inbounds{serverHTTP},
   209  	})
   210  
   211  	serverDisp.Register(raw.Procedure("hello",
   212  		func(ctx context.Context, body []byte) ([]byte, error) {
   213  			return append(body, []byte(", World")...), nil
   214  		}))
   215  
   216  	require.NoError(t, serverDisp.Start())
   217  	defer serverDisp.Stop()
   218  
   219  	clientDisp := yarpc.NewDispatcher(yarpc.Config{
   220  		Name: "client",
   221  		Outbounds: yarpc.Outbounds{
   222  			"server": {
   223  				Unary: httpTransport.NewSingleOutbound(fmt.Sprintf("http://%s", yarpctest.ZeroAddrToHostPort(serverHTTP.Addr()))),
   224  			},
   225  		},
   226  		OutboundMiddleware: yarpc.OutboundMiddleware{
   227  			Unary: recorder,
   228  		},
   229  	})
   230  	require.NoError(t, clientDisp.Start())
   231  	defer clientDisp.Stop()
   232  
   233  	client := raw.New(clientDisp.ClientConfig("server"))
   234  	f(client)
   235  }
   236  
   237  func TestEndToEnd(t *testing.T) {
   238  	tMock := testingTMock{t, 0}
   239  
   240  	dir, err := ioutil.TempDir("", "yarpcgorecorder")
   241  	if err != nil {
   242  		t.Fatal(err)
   243  	}
   244  	defer os.RemoveAll(dir) // clean up
   245  
   246  	// First we double check that our cache is empty.
   247  	recorder := NewRecorder(&tMock, RecordMode(Replay), RecordsPath(dir))
   248  
   249  	withDisconnectedClient(t, recorder, func(client raw.Client) {
   250  		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   251  		defer cancel()
   252  
   253  		require.Panics(t, func() {
   254  			client.Call(ctx, "hello", []byte("Hello"))
   255  		})
   256  		assert.Equal(t, tMock.fatalCount, 1)
   257  	})
   258  
   259  	// Now let's record our call.
   260  	recorder = NewRecorder(&tMock, RecordMode(Overwrite), RecordsPath(dir))
   261  
   262  	withConnectedClient(t, recorder, func(client raw.Client) {
   263  		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   264  		defer cancel()
   265  
   266  		rbody, err := client.Call(ctx, "hello", []byte("Hello"))
   267  		require.NoError(t, err)
   268  		assert.Equal(t, rbody, []byte("Hello, World"))
   269  	})
   270  
   271  	// Now replay the call.
   272  	recorder = NewRecorder(&tMock, RecordMode(Replay), RecordsPath(dir))
   273  
   274  	withDisconnectedClient(t, recorder, func(client raw.Client) {
   275  		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   276  		defer cancel()
   277  
   278  		rbody, err := client.Call(ctx, "hello", []byte("Hello"))
   279  		require.NoError(t, err)
   280  		assert.Equal(t, rbody, []byte("Hello, World"))
   281  	})
   282  }
   283  
   284  func TestEmptyReplay(t *testing.T) {
   285  	tMock := testingTMock{t, 0}
   286  
   287  	dir, err := ioutil.TempDir("", "yarpcgorecorder")
   288  	if err != nil {
   289  		t.Fatal(err)
   290  	}
   291  	defer os.RemoveAll(dir) // clean up
   292  
   293  	recorder := NewRecorder(&tMock, RecordMode(Replay), RecordsPath(dir))
   294  
   295  	withDisconnectedClient(t, recorder, func(client raw.Client) {
   296  		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   297  		defer cancel()
   298  
   299  		require.Panics(t, func() {
   300  			client.Call(ctx, "hello", []byte("Hello"))
   301  		})
   302  		assert.Equal(t, tMock.fatalCount, 1)
   303  	})
   304  }
   305  
   306  const refRecordFilename = `server.hello.254fa3bab61fc27f.yaml`
   307  const refRecordContent = recordComment +
   308  	`version: 1
   309  request:
   310    caller: client
   311    service: server
   312    procedure: hello
   313    encoding: raw
   314    headers: {}
   315    shardkey: ""
   316    routingkey: ""
   317    routingdelegate: ""
   318    body: SGVsbG8=
   319  response:
   320    headers: {}
   321    body: SGVsbG8sIFdvcmxk
   322  `
   323  
   324  func TestRecording(t *testing.T) {
   325  	tMock := testingTMock{t, 0}
   326  
   327  	dir, err := ioutil.TempDir("", "yarpcgorecorder")
   328  	if err != nil {
   329  		t.Fatal(err)
   330  	}
   331  	defer os.RemoveAll(dir) // clean up
   332  
   333  	recorder := NewRecorder(&tMock, RecordMode(Append), RecordsPath(dir))
   334  
   335  	withConnectedClient(t, recorder, func(client raw.Client) {
   336  		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   337  		defer cancel()
   338  
   339  		rbody, err := client.Call(ctx, "hello", []byte("Hello"))
   340  		require.NoError(t, err)
   341  		assert.Equal(t, []byte("Hello, World"), rbody)
   342  	})
   343  
   344  	recordPath := path.Join(dir, refRecordFilename)
   345  	_, err = os.Stat(recordPath)
   346  	require.NoError(t, err)
   347  
   348  	recordContent, err := ioutil.ReadFile(recordPath)
   349  	require.NoError(t, err)
   350  	assert.Equal(t, refRecordContent, string(recordContent))
   351  }
   352  
   353  func TestReplaying(t *testing.T) {
   354  	tMock := testingTMock{t, 0}
   355  
   356  	dir, err := ioutil.TempDir("", "yarpcgorecorder")
   357  	if err != nil {
   358  		t.Fatal(err)
   359  	}
   360  	defer os.RemoveAll(dir) // clean up
   361  
   362  	recorder := NewRecorder(&tMock, RecordMode(Replay), RecordsPath(dir))
   363  
   364  	recordPath := path.Join(dir, refRecordFilename)
   365  	err = ioutil.WriteFile(recordPath, []byte(refRecordContent), 0444)
   366  	require.NoError(t, err)
   367  
   368  	withDisconnectedClient(t, recorder, func(client raw.Client) {
   369  		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   370  		defer cancel()
   371  
   372  		rbody, err := client.Call(ctx, "hello", []byte("Hello"))
   373  		require.NoError(t, err)
   374  		assert.Equal(t, rbody, []byte("Hello, World"))
   375  	})
   376  }