go.uber.org/yarpc@v1.72.1/yarpctest/recorder/recorder.go (about)

     1  // Copyright (c) 2022 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
// Package recorder records and replays yarpc requests on the client side.
    22  //
    23  // For recording, the client must be connected and able to issue requests to a
    24  // remote service. Every request and its response is recorded into a YAML file,
    25  // under the directory "testdata/recordings" relative to the test directory.
    26  //
// During replay, the client doesn't need to be connected: for any recorded
// request, Recorder returns the recorded response. Any new request (i.e., one
// that was not pre-recorded) aborts the test.
    30  //
    31  // NewRecorder() returns a Recorder, in the mode specified by the flag
    32  // `--recorder=replay|append|overwrite`. `replay` is the default.
    33  //
// The new Recorder instance is a yarpc unary outbound middleware. It takes a
// `testing.T` (or any other value implementing TestingT) as argument.
    36  //
    37  // Example:
    38  //  func MyTest(t *testing.T) {
    39  //    dispatcher := yarpc.NewDispatcher(yarpc.Config{
    40  //    	Name: "...",
    41  //    	Outbounds: transport.Outbounds{
    42  //    		...
    43  //    	},
    44  //      OutboundMiddleware: yarpc.OutboundMiddleware {
    45  //    	  Unary: recorder.NewRecorder(t),
    46  //      },
    47  //    })
    48  //  }
    49  //
    50  // Running the tests in append mode:
    51  //  $ go test -v ./... --recorder=append
    52  //
    53  // The recorded messages will be stored in
    54  // `./testdata/recordings/*.yaml`.
    55  package recorder
    56  
    57  import (
    58  	"bytes"
    59  	"context"
    60  	"encoding/base64"
    61  	"flag"
    62  	"fmt"
    63  	"hash/fnv"
    64  	"io/ioutil"
    65  	"os"
    66  	"path/filepath"
    67  	"sort"
    68  	"strings"
    69  	"unicode"
    70  
    71  	"go.uber.org/yarpc/api/transport"
    72  	"gopkg.in/yaml.v2"
    73  )
    74  
    75  var recorderFlag = flag.String("recorder", "replay",
    76  	`replay: replay from recorded request/response pairs.
    77  overwrite: record all request/response pairs, overwriting records.
    78  append: replay existing and record new request/response pairs.`)
    79  
// Recorder records and replays yarpc requests on the client side.
//
// For recording, the client must be connected and able to issue requests to a
// remote service. Every request and its response is recorded into a YAML file
// under the records directory: "testdata/recordings" in the working directory
// by default, or the directory given with the RecordsPath option.
//
// During replay, the client doesn't need to be connected: for any recorded
// request, Recorder returns the recorded response. Any new request aborts the
// test by calling logger.Fatal().
//
// Fields:
//   - mode selects whether requests are replayed, recorded, or both (see
//     Mode and SetMode).
//   - logger is used for logging and for failing the current test.
//   - recordsDir is the directory holding the recording files.
type Recorder struct {
	mode       Mode
	logger     TestingT
	recordsDir string
}
    94  
    95  const defaultRecorderDir = "testdata/recordings"
    96  const recordComment = `# In order to update this recording, setup your external dependencies and run
    97  # ` + "`" + `go test <insert test files here> --recorder=replay|append|overwrite` + "`\n"
    98  const currentRecordVersion = 1
    99  
   100  // Mode is the recording mode of the recorder.
   101  type Mode int
   102  
   103  const (
   104  	// invalidMode is private and used to represent invalid modes.
   105  	invalidMode Mode = iota
   106  
   107  	// Replay replays stored request/response pairs, any non pre-recorded
   108  	// requests will be rejected.
   109  	Replay
   110  
   111  	// Overwrite will store all request/response pairs, overwriting existing
   112  	// records.
   113  	Overwrite
   114  
   115  	// Append will store all new request/response pairs and replay from
   116  	// existing record.
   117  	Append
   118  )
   119  
   120  func (m Mode) toHumanString() string {
   121  	switch m {
   122  	case Replay:
   123  		return "replaying"
   124  	case Overwrite:
   125  		return "recording (overwrite)"
   126  	case Append:
   127  		return "recording (append)"
   128  	default:
   129  		return fmt.Sprintf("Mode(%d)", int(m))
   130  	}
   131  }
   132  
   133  // modeFromString converts an English string of a mode to a `Mode`.
   134  func modeFromString(s string) (Mode, error) {
   135  	switch s {
   136  	case "replay":
   137  		return Replay, nil
   138  	case "overwrite":
   139  		return Overwrite, nil
   140  	case "append":
   141  		return Append, nil
   142  	}
   143  	return invalidMode, fmt.Errorf(`invalid mode: "%s"`, s)
   144  }
   145  
   146  // TestingT is an interface used by the recorder for logging and reporting fatal
   147  // errors. It is intentionally made to match with testing.T.
   148  type TestingT interface {
   149  	// Logf must behaves similarly to testing.T.Logf.
   150  	Logf(format string, args ...interface{})
   151  
   152  	// Fatal should behaves similarly to testing.T.Fatal. Namely, it must abort
   153  	// the current test.
   154  	Fatal(args ...interface{})
   155  }
   156  
   157  // NewRecorder returns a Recorder in whatever mode specified via the
   158  // `--recorder` flag.
   159  //
   160  // The new Recorder instance is a yarpc unary outbound middleware. It takes a
   161  // logger as argument compatible with `testing.T`.
   162  //
   163  // See package documentation for more details.
   164  func NewRecorder(logger TestingT, opts ...Option) *Recorder {
   165  	cwd, err := os.Getwd()
   166  	if err != nil {
   167  		logger.Fatal(err)
   168  	}
   169  	recorder := &Recorder{
   170  		logger: logger,
   171  	}
   172  
   173  	var cfg config
   174  	for _, opt := range opts {
   175  		opt(&cfg)
   176  	}
   177  
   178  	if cfg.RecordsPath != "" {
   179  		recorder.recordsDir = cfg.RecordsPath
   180  	} else {
   181  		recorder.recordsDir = filepath.Join(cwd, defaultRecorderDir)
   182  	}
   183  
   184  	mode := cfg.Mode
   185  	if mode == invalidMode {
   186  		mode, err = modeFromString(*recorderFlag)
   187  		if err != nil {
   188  			logger.Fatal(err)
   189  		}
   190  	}
   191  	recorder.SetMode(mode)
   192  	return recorder
   193  }
   194  
   195  // RecordMode sets the mode.
   196  func RecordMode(mode Mode) Option {
   197  	return func(cfg *config) {
   198  		cfg.Mode = mode
   199  	}
   200  }
   201  
   202  // RecordsPath sets the records directory path.
   203  func RecordsPath(path string) Option {
   204  	return func(cfg *config) {
   205  		cfg.RecordsPath = path
   206  	}
   207  }
   208  
   209  // Option is the type used for the functional options pattern.
   210  type Option func(*config)
   211  
   212  type config struct {
   213  	Mode        Mode
   214  	RecordsPath string
   215  }
   216  
   217  // SetMode let you choose enable the different replay and recording modes,
   218  // overriding the --recorder flag.
   219  func (r *Recorder) SetMode(newMode Mode) {
   220  	if r.mode == newMode {
   221  		return
   222  	}
   223  	r.mode = newMode
   224  	r.logger.Logf("recorder %s from/to %v", r.mode.toHumanString(), r.recordsDir)
   225  }
   226  
   227  func sanitizeFilename(s string) (r string) {
   228  	const allowedRunes = `_-.`
   229  	return strings.Map(func(rv rune) rune {
   230  		if unicode.IsLetter(rv) || unicode.IsNumber(rv) {
   231  			return rv
   232  		}
   233  		if strings.ContainsRune(allowedRunes, rv) {
   234  			return rv
   235  		}
   236  		return '_'
   237  	}, s)
   238  }
   239  
   240  func (r *Recorder) hashRequestRecord(requestRecord *requestRecord) string {
   241  	log := r.logger
   242  	hash := fnv.New64a()
   243  
   244  	ha := func(b string) {
   245  		_, err := hash.Write([]byte(b))
   246  		if err != nil {
   247  			log.Fatal(err)
   248  		}
   249  		_, err = hash.Write([]byte("."))
   250  		if err != nil {
   251  			log.Fatal(err)
   252  		}
   253  	}
   254  
   255  	ha(requestRecord.Caller)
   256  	ha(requestRecord.Service)
   257  	ha(string(requestRecord.Encoding))
   258  	ha(requestRecord.Procedure)
   259  
   260  	orderedHeadersKeys := make([]string, 0, len(requestRecord.Headers))
   261  	for k := range requestRecord.Headers {
   262  		orderedHeadersKeys = append(orderedHeadersKeys, k)
   263  	}
   264  	sort.Strings(orderedHeadersKeys)
   265  	for _, k := range orderedHeadersKeys {
   266  		ha(k)
   267  		ha(requestRecord.Headers[k])
   268  	}
   269  
   270  	ha(requestRecord.ShardKey)
   271  	ha(requestRecord.RoutingKey)
   272  	ha(requestRecord.RoutingDelegate)
   273  
   274  	_, err := hash.Write(requestRecord.Body)
   275  	if err != nil {
   276  		log.Fatal(err)
   277  	}
   278  	return fmt.Sprintf("%x", hash.Sum64())
   279  }
   280  
   281  func (r *Recorder) makeFilePath(request *transport.Request, hash string) string {
   282  	s := fmt.Sprintf("%s.%s.%s.yaml", request.Service, request.Procedure, hash)
   283  	return filepath.Join(r.recordsDir, sanitizeFilename(s))
   284  }
   285  
   286  // Call implements the yarpc transport outbound middleware interface
   287  func (r *Recorder) Call(
   288  	ctx context.Context,
   289  	request *transport.Request,
   290  	out transport.UnaryOutbound) (*transport.Response, error) {
   291  	log := r.logger
   292  
   293  	requestRecord := r.requestToRequestRecord(request)
   294  
   295  	requestHash := r.hashRequestRecord(&requestRecord)
   296  	filepath := r.makeFilePath(request, requestHash)
   297  
   298  	switch r.mode {
   299  	case Replay:
   300  		cachedRecord, err := r.loadRecord(filepath)
   301  		if err != nil {
   302  			log.Fatal(err)
   303  		}
   304  		response := r.recordToResponse(cachedRecord)
   305  		return &response, nil
   306  	case Append:
   307  		cachedRecord, err := r.loadRecord(filepath)
   308  		if err == nil {
   309  			response := r.recordToResponse(cachedRecord)
   310  			return &response, nil
   311  		}
   312  		fallthrough
   313  	case Overwrite:
   314  		response, err := out.Call(ctx, request)
   315  		if err == nil {
   316  			cachedRecord := record{
   317  				Version:  currentRecordVersion,
   318  				Request:  requestRecord,
   319  				Response: r.responseToResponseRecord(response),
   320  			}
   321  			r.saveRecord(filepath, &cachedRecord)
   322  		}
   323  		return response, err
   324  	default:
   325  		panic(fmt.Sprintf("invalid record mode: %v", r.mode))
   326  	}
   327  }
   328  
   329  func (r *Recorder) recordToResponse(cachedRecord *record) transport.Response {
   330  	response := transport.Response{
   331  		Headers: transport.HeadersFromMap(cachedRecord.Response.Headers),
   332  		Body:    ioutil.NopCloser(bytes.NewReader(cachedRecord.Response.Body)),
   333  	}
   334  	return response
   335  }
   336  
   337  func (r *Recorder) requestToRequestRecord(request *transport.Request) requestRecord {
   338  	requestBody, err := ioutil.ReadAll(request.Body)
   339  	if err != nil {
   340  		r.logger.Fatal(err)
   341  	}
   342  	request.Body = ioutil.NopCloser(bytes.NewReader(requestBody))
   343  	return requestRecord{
   344  		Caller:          request.Caller,
   345  		Service:         request.Service,
   346  		Procedure:       request.Procedure,
   347  		Encoding:        string(request.Encoding),
   348  		Headers:         request.Headers.Items(),
   349  		ShardKey:        request.ShardKey,
   350  		RoutingKey:      request.RoutingKey,
   351  		RoutingDelegate: request.RoutingDelegate,
   352  		Body:            requestBody,
   353  	}
   354  }
   355  
   356  func (r *Recorder) responseToResponseRecord(response *transport.Response) responseRecord {
   357  	responseBody, err := ioutil.ReadAll(response.Body)
   358  	if err != nil {
   359  		r.logger.Fatal(err)
   360  	}
   361  	response.Body = ioutil.NopCloser(bytes.NewReader(responseBody))
   362  	return responseRecord{
   363  		Headers: response.Headers.Items(),
   364  		Body:    responseBody,
   365  	}
   366  }
   367  
   368  // loadRecord attempts to load a record from the given file. If the record
   369  // cannot be found the errRecordNotFound is returned. Any other error will
   370  // abort the current test.
   371  func (r *Recorder) loadRecord(filepath string) (*record, error) {
   372  	rawRecord, err := ioutil.ReadFile(filepath)
   373  	if err != nil {
   374  		if os.IsNotExist(err) {
   375  			return nil, newErrRecordNotFound(err)
   376  		}
   377  		r.logger.Fatal(err)
   378  	}
   379  	var cachedRecord record
   380  	if err := yaml.Unmarshal(rawRecord, &cachedRecord); err != nil {
   381  		r.logger.Fatal(err)
   382  	}
   383  
   384  	if cachedRecord.Version != currentRecordVersion {
   385  		r.logger.Fatal(fmt.Sprintf("unsupported record version %d (expected %d)",
   386  			cachedRecord.Version, currentRecordVersion))
   387  	}
   388  
   389  	return &cachedRecord, nil
   390  }
   391  
   392  // saveRecord attempts to save a record to the given file, any error fails the
   393  // current test.
   394  func (r *Recorder) saveRecord(filepath string, cachedRecord *record) {
   395  	if err := os.MkdirAll(defaultRecorderDir, 0775); err != nil {
   396  		r.logger.Fatal(err)
   397  	}
   398  
   399  	rawRecord, err := yaml.Marshal(&cachedRecord)
   400  	if err != nil {
   401  		r.logger.Fatal(err)
   402  	}
   403  
   404  	file, err := os.OpenFile(filepath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0664)
   405  	if err != nil {
   406  		r.logger.Fatal(err)
   407  	}
   408  
   409  	if _, err := file.Write([]byte(recordComment)); err != nil {
   410  		r.logger.Fatal(err)
   411  	}
   412  
   413  	if _, err := file.Write(rawRecord); err != nil {
   414  		r.logger.Fatal(err)
   415  	}
   416  }
   417  
   418  type errRecordNotFound struct {
   419  	underlyingError error
   420  }
   421  
   422  func newErrRecordNotFound(underlyingError error) errRecordNotFound {
   423  	return errRecordNotFound{underlyingError}
   424  }
   425  
   426  func (e errRecordNotFound) Error() string {
   427  	return fmt.Sprintf("record not found (%s)", e.underlyingError)
   428  }
   429  
// requestRecord is the serialized form of a transport.Request, built by
// requestToRequestRecord. Every field is fed into hashRequestRecord, whose
// digest becomes part of the recording's file name, so a change to any field
// selects a different recording. Body is stored as base64 text (see
// base64blob).
//
// NOTE(review): the YAML encoder presumably writes fields in declaration
// order, so reordering them would likely churn existing recording files.
// Confirm before changing.
type requestRecord struct {
	Caller          string
	Service         string
	Procedure       string
	Encoding        string
	Headers         map[string]string
	ShardKey        string
	RoutingKey      string
	RoutingDelegate string
	Body            base64blob
}

// responseRecord is the serialized form of a transport.Response. Only the
// headers and the body (as base64 text) are recorded and restored on replay.
type responseRecord struct {
	Headers map[string]string
	Body    base64blob
}

// record is the document stored in one recording file: a request/response
// pair plus a format Version, which loadRecord requires to equal
// currentRecordVersion.
type record struct {
	Version  uint
	Request  requestRecord
	Response responseRecord
}
   452  
   453  type base64blob []byte
   454  
   455  func (b base64blob) MarshalYAML() (interface{}, error) {
   456  	return base64.StdEncoding.EncodeToString(b), nil
   457  }
   458  
   459  func (b *base64blob) UnmarshalYAML(unmarshal func(interface{}) error) error {
   460  	var base64encoded string
   461  	if err := unmarshal(&base64encoded); err != nil {
   462  		return err
   463  	}
   464  	decoded, err := base64.StdEncoding.DecodeString(base64encoded)
   465  	if err != nil {
   466  		return err
   467  	}
   468  	*b = decoded
   469  	return nil
   470  }