go.uber.org/yarpc@v1.72.1/internal/observability/middleware_test.go (about)

     1  // Copyright (c) 2022 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package observability
    22  
    23  import (
    24  	"bytes"
    25  	"context"
    26  	"errors"
    27  	"fmt"
    28  	"io"
    29  	"strings"
    30  	"sync"
    31  	"testing"
    32  	"time"
    33  
    34  	"github.com/stretchr/testify/assert"
    35  	"github.com/stretchr/testify/require"
    36  	"go.uber.org/net/metrics"
    37  	"go.uber.org/yarpc/api/transport"
    38  	"go.uber.org/yarpc/api/transport/transporttest"
    39  	"go.uber.org/yarpc/internal/bufferpool"
    40  	"go.uber.org/yarpc/internal/digester"
    41  	"go.uber.org/yarpc/yarpcerrors"
    42  	"go.uber.org/zap"
    43  	"go.uber.org/zap/zapcore"
    44  	"go.uber.org/zap/zaptest/observer"
    45  )
    46  
    47  func TestNewMiddlewareLogLevels(t *testing.T) {
    48  	// It's a bit unfortunate that we're asserting conditions about the
    49  	// internal state of Middleware and graph here but short of duplicating
    50  	// the other test, this is the cleanest option.
    51  
    52  	infoLevel := zapcore.InfoLevel
    53  	warnLevel := zapcore.WarnLevel
    54  
    55  	t.Run("Inbound", func(t *testing.T) {
    56  		t.Run("Success", func(t *testing.T) {
    57  			t.Run("default", func(t *testing.T) {
    58  				assert.Equal(t, zapcore.DebugLevel, NewMiddleware(Config{}).graph.inboundLevels.success)
    59  			})
    60  
    61  			t.Run("any direction override", func(t *testing.T) {
    62  				assert.Equal(t, zapcore.InfoLevel, NewMiddleware(Config{
    63  					Levels: LevelsConfig{
    64  						Default: DirectionalLevelsConfig{
    65  							Success: &infoLevel,
    66  						},
    67  					},
    68  				}).graph.inboundLevels.success)
    69  			})
    70  
    71  			t.Run("directional override", func(t *testing.T) {
    72  				assert.Equal(t, zapcore.InfoLevel, NewMiddleware(Config{
    73  					Levels: LevelsConfig{
    74  						Default: DirectionalLevelsConfig{
    75  							Success: &warnLevel, // overridden by Inbound.Success
    76  						},
    77  						Inbound: DirectionalLevelsConfig{
    78  							Success: &infoLevel, // overrides Default.Success
    79  						},
    80  					},
    81  				}).graph.inboundLevels.success)
    82  			})
    83  		})
    84  
    85  		t.Run("Failure", func(t *testing.T) {
    86  			t.Run("default", func(t *testing.T) {
    87  				m := NewMiddleware(Config{})
    88  				assert.Equal(t, zapcore.ErrorLevel, m.graph.inboundLevels.failure)
    89  				assert.False(t, m.graph.inboundLevels.useApplicationErrorFailureLevels)
    90  			})
    91  
    92  			t.Run("override", func(t *testing.T) {
    93  				m := NewMiddleware(Config{
    94  					Levels: LevelsConfig{
    95  						Inbound: DirectionalLevelsConfig{
    96  							Failure: &warnLevel,
    97  						},
    98  					},
    99  				})
   100  				assert.Equal(t, zapcore.WarnLevel, m.graph.inboundLevels.failure)
   101  				assert.True(t, m.graph.inboundLevels.useApplicationErrorFailureLevels)
   102  			})
   103  		})
   104  
   105  		t.Run("ApplicationError", func(t *testing.T) {
   106  			t.Run("default", func(t *testing.T) {
   107  				m := NewMiddleware(Config{})
   108  				assert.Equal(t, zapcore.ErrorLevel, m.graph.inboundLevels.applicationError)
   109  				assert.False(t, m.graph.inboundLevels.useApplicationErrorFailureLevels)
   110  			})
   111  
   112  			t.Run("override", func(t *testing.T) {
   113  				m := NewMiddleware(Config{
   114  					Levels: LevelsConfig{
   115  						Inbound: DirectionalLevelsConfig{
   116  							ApplicationError: &warnLevel,
   117  						},
   118  					},
   119  				})
   120  				assert.Equal(t, zapcore.WarnLevel, m.graph.inboundLevels.applicationError)
   121  				assert.True(t, m.graph.inboundLevels.useApplicationErrorFailureLevels)
   122  			})
   123  		})
   124  
   125  		t.Run("ClientError", func(t *testing.T) {
   126  			t.Run("default", func(t *testing.T) {
   127  				m := NewMiddleware(Config{})
   128  				assert.Equal(t, zapcore.ErrorLevel, m.graph.inboundLevels.clientError)
   129  				assert.False(t, m.graph.inboundLevels.useApplicationErrorFailureLevels)
   130  			})
   131  
   132  			t.Run("override", func(t *testing.T) {
   133  				m := NewMiddleware(Config{
   134  					Levels: LevelsConfig{
   135  						Inbound: DirectionalLevelsConfig{
   136  							ClientError: &warnLevel,
   137  						},
   138  					},
   139  				})
   140  				assert.Equal(t, zapcore.WarnLevel, m.graph.inboundLevels.clientError)
   141  				assert.False(t, m.graph.inboundLevels.useApplicationErrorFailureLevels)
   142  			})
   143  		})
   144  
   145  		t.Run("serverError", func(t *testing.T) {
   146  			t.Run("default", func(t *testing.T) {
   147  				m := NewMiddleware(Config{})
   148  				assert.Equal(t, zapcore.ErrorLevel, m.graph.inboundLevels.serverError)
   149  				assert.False(t, m.graph.inboundLevels.useApplicationErrorFailureLevels)
   150  			})
   151  
   152  			t.Run("override", func(t *testing.T) {
   153  				m := NewMiddleware(Config{
   154  					Levels: LevelsConfig{
   155  						Inbound: DirectionalLevelsConfig{
   156  							ServerError: &warnLevel,
   157  						},
   158  					},
   159  				})
   160  				assert.Equal(t, zapcore.WarnLevel, m.graph.inboundLevels.serverError)
   161  				assert.False(t, m.graph.inboundLevels.useApplicationErrorFailureLevels)
   162  			})
   163  		})
   164  	})
   165  
   166  	t.Run("Outbound", func(t *testing.T) {
   167  		t.Run("Success", func(t *testing.T) {
   168  			t.Run("default", func(t *testing.T) {
   169  				m := NewMiddleware(Config{})
   170  				assert.Equal(t, zapcore.DebugLevel, m.graph.outboundLevels.success)
   171  				assert.False(t, m.graph.outboundLevels.useApplicationErrorFailureLevels)
   172  			})
   173  
   174  			t.Run("override", func(t *testing.T) {
   175  				m := NewMiddleware(Config{
   176  					Levels: LevelsConfig{
   177  						Outbound: DirectionalLevelsConfig{
   178  							Success: &infoLevel,
   179  						},
   180  					},
   181  				})
   182  				assert.Equal(t, zapcore.InfoLevel, m.graph.outboundLevels.success)
   183  				assert.False(t, m.graph.outboundLevels.useApplicationErrorFailureLevels)
   184  			})
   185  		})
   186  
   187  		t.Run("Failure", func(t *testing.T) {
   188  			t.Run("default", func(t *testing.T) {
   189  				m := NewMiddleware(Config{})
   190  				assert.Equal(t, zapcore.ErrorLevel, m.graph.outboundLevels.failure)
   191  				assert.False(t, m.graph.outboundLevels.useApplicationErrorFailureLevels)
   192  			})
   193  
   194  			t.Run("override", func(t *testing.T) {
   195  				m := NewMiddleware(Config{
   196  					Levels: LevelsConfig{
   197  						Outbound: DirectionalLevelsConfig{
   198  							Failure: &warnLevel,
   199  						},
   200  					},
   201  				})
   202  				assert.Equal(t, zapcore.WarnLevel, m.graph.outboundLevels.failure)
   203  				assert.True(t, m.graph.outboundLevels.useApplicationErrorFailureLevels)
   204  			})
   205  		})
   206  
   207  		t.Run("ApplicationError", func(t *testing.T) {
   208  			t.Run("default", func(t *testing.T) {
   209  				m := NewMiddleware(Config{})
   210  				assert.Equal(t, zapcore.ErrorLevel, m.graph.outboundLevels.applicationError)
   211  				assert.False(t, m.graph.outboundLevels.useApplicationErrorFailureLevels)
   212  			})
   213  
   214  			t.Run("override", func(t *testing.T) {
   215  				m := NewMiddleware(Config{
   216  					Levels: LevelsConfig{
   217  						Outbound: DirectionalLevelsConfig{
   218  							ApplicationError: &warnLevel,
   219  						},
   220  					},
   221  				})
   222  				assert.Equal(t, zapcore.WarnLevel, m.graph.outboundLevels.applicationError)
   223  				assert.True(t, m.graph.outboundLevels.useApplicationErrorFailureLevels)
   224  			})
   225  		})
   226  
   227  		t.Run("ClientError", func(t *testing.T) {
   228  			t.Run("default", func(t *testing.T) {
   229  				m := NewMiddleware(Config{})
   230  				assert.Equal(t, zapcore.ErrorLevel, m.graph.outboundLevels.clientError)
   231  				assert.False(t, m.graph.outboundLevels.useApplicationErrorFailureLevels)
   232  			})
   233  
   234  			t.Run("override", func(t *testing.T) {
   235  				m := NewMiddleware(Config{
   236  					Levels: LevelsConfig{
   237  						Outbound: DirectionalLevelsConfig{
   238  							ClientError: &warnLevel,
   239  						},
   240  					},
   241  				})
   242  				assert.Equal(t, zapcore.WarnLevel, m.graph.outboundLevels.clientError)
   243  				assert.False(t, m.graph.outboundLevels.useApplicationErrorFailureLevels)
   244  			})
   245  		})
   246  
   247  		t.Run("ServerError", func(t *testing.T) {
   248  			t.Run("default", func(t *testing.T) {
   249  				m := NewMiddleware(Config{})
   250  				assert.Equal(t, zapcore.ErrorLevel, m.graph.outboundLevels.serverError)
   251  				assert.False(t, m.graph.outboundLevels.useApplicationErrorFailureLevels)
   252  			})
   253  
   254  			t.Run("override", func(t *testing.T) {
   255  				m := NewMiddleware(Config{
   256  					Levels: LevelsConfig{
   257  						Outbound: DirectionalLevelsConfig{
   258  							ServerError: &warnLevel,
   259  						},
   260  					},
   261  				})
   262  				assert.Equal(t, zapcore.WarnLevel, m.graph.outboundLevels.serverError)
   263  				assert.False(t, m.graph.outboundLevels.useApplicationErrorFailureLevels)
   264  			})
   265  		})
   266  	})
   267  
   268  }
   269  
   270  func TestMiddlewareLoggingWithApplicationErrorConfiguration(t *testing.T) {
   271  	timeVal := time.Now()
   272  	defer stubTimeWithTimeVal(timeVal)()
   273  	ttl := time.Millisecond * 1000
   274  	ctx, cancel := context.WithDeadline(context.Background(), timeVal.Add(ttl))
   275  	defer cancel()
   276  
   277  	req := &transport.Request{
   278  		Caller:          "caller",
   279  		Service:         "service",
   280  		Transport:       "",
   281  		Encoding:        "raw",
   282  		Procedure:       "procedure",
   283  		Headers:         transport.NewHeaders().With("password", "super-secret"),
   284  		ShardKey:        "shard01",
   285  		RoutingKey:      "routing-key",
   286  		RoutingDelegate: "routing-delegate",
   287  		Body:            strings.NewReader("body"),
   288  	}
   289  
   290  	rawErr := errors.New("fail")
   291  	yErrNoDetails := yarpcerrors.Newf(yarpcerrors.CodeAborted, "fail")
   292  	yErrWithDetails := yarpcerrors.Newf(yarpcerrors.CodeAborted, "fail").WithDetails([]byte("err detail"))
   293  	yErrResourceExhausted := yarpcerrors.CodeResourceExhausted
   294  	appErrDetails := "an app error detail string, usually from thriftEx.Error()!"
   295  
   296  	baseFields := func() []zapcore.Field {
   297  		return []zapcore.Field{
   298  			zap.String("source", req.Caller),
   299  			zap.String("dest", req.Service),
   300  			zap.String("transport", unknownIfEmpty(req.Transport)),
   301  			zap.String("procedure", req.Procedure),
   302  			zap.String("encoding", string(req.Encoding)),
   303  			zap.String("routingKey", req.RoutingKey),
   304  			zap.String("routingDelegate", req.RoutingDelegate),
   305  		}
   306  	}
   307  
   308  	type test struct {
   309  		desc                  string
   310  		err                   error             // downstream error
   311  		applicationErr        bool              // downstream application error
   312  		applicationErrName    string            // downstream application error name
   313  		applicationErrDetails string            // downstream application error message
   314  		applicationErrCode    *yarpcerrors.Code // downstream application error code
   315  		wantErrLevel          zapcore.Level
   316  		wantInboundMsg        string
   317  		wantOutboundMsg       string
   318  		wantFields            []zapcore.Field
   319  	}
   320  
   321  	tests := []test{
   322  		{
   323  			desc:            "success",
   324  			wantErrLevel:    zapcore.InfoLevel,
   325  			wantInboundMsg:  "Handled inbound request.",
   326  			wantOutboundMsg: "Made outbound call.",
   327  			wantFields: []zapcore.Field{
   328  				zap.Duration("latency", 0),
   329  				zap.Bool("successful", true),
   330  				zap.Skip(), // ContextExtractor
   331  				zap.Duration("timeout", ttl),
   332  			},
   333  		},
   334  		{
   335  			desc:            "downstream transport error",
   336  			err:             rawErr,
   337  			wantErrLevel:    zapcore.ErrorLevel,
   338  			wantInboundMsg:  "Error handling inbound request.",
   339  			wantOutboundMsg: "Error making outbound call.",
   340  			wantFields: []zapcore.Field{
   341  				zap.Duration("latency", 0),
   342  				zap.Bool("successful", false),
   343  				zap.Skip(),
   344  				zap.Duration("timeout", ttl),
   345  				zap.Error(rawErr),
   346  				zap.String(_errorCodeLogKey, "unknown"),
   347  			},
   348  		},
   349  		{
   350  			desc:                  "thrift application error with no name",
   351  			applicationErr:        true,
   352  			applicationErrDetails: appErrDetails,
   353  			wantErrLevel:          zapcore.WarnLevel,
   354  			wantInboundMsg:        "Error handling inbound request.",
   355  			wantOutboundMsg:       "Error making outbound call.",
   356  			wantFields: []zapcore.Field{
   357  				zap.Duration("latency", 0),
   358  				zap.Bool("successful", false),
   359  				zap.Skip(),
   360  				zap.Duration("timeout", ttl),
   361  				zap.String("error", "application_error"),
   362  				zap.String("errorDetails", appErrDetails),
   363  			},
   364  		},
   365  		{
   366  			desc:                  "thrift application error with name and code",
   367  			applicationErr:        true,
   368  			applicationErrName:    "FunkyThriftError",
   369  			applicationErrDetails: appErrDetails,
   370  			applicationErrCode:    &yErrResourceExhausted,
   371  			wantErrLevel:          zapcore.WarnLevel,
   372  			wantInboundMsg:        "Error handling inbound request.",
   373  			wantOutboundMsg:       "Error making outbound call.",
   374  			wantFields: []zapcore.Field{
   375  				zap.Duration("latency", 0),
   376  				zap.Bool("successful", false),
   377  				zap.Skip(),
   378  				zap.Duration("timeout", ttl),
   379  				zap.String("error", "application_error"),
   380  				zap.String("errorCode", "resource-exhausted"),
   381  				zap.String("errorName", "FunkyThriftError"),
   382  				zap.String("errorDetails", appErrDetails),
   383  			},
   384  		},
   385  		{
   386  			// ie 'errors.New' return in Protobuf handler
   387  			desc:            "err and app error",
   388  			err:             rawErr,
   389  			applicationErr:  true, // always true for Protobuf handler errors
   390  			wantErrLevel:    zapcore.ErrorLevel,
   391  			wantInboundMsg:  "Error handling inbound request.",
   392  			wantOutboundMsg: "Error making outbound call.",
   393  			wantFields: []zapcore.Field{
   394  				zap.Duration("latency", 0),
   395  				zap.Bool("successful", false),
   396  				zap.Skip(),
   397  				zap.Duration("timeout", ttl),
   398  				zap.Error(rawErr),
   399  				zap.String(_errorCodeLogKey, "unknown"),
   400  			},
   401  		},
   402  		{
   403  			// ie 'yarpcerror' or 'protobuf.NewError` return in Protobuf handler
   404  			desc:            "yarpcerror, app error",
   405  			err:             yErrNoDetails,
   406  			applicationErr:  true, // always true for Protobuf handler errors
   407  			wantErrLevel:    zapcore.ErrorLevel,
   408  			wantInboundMsg:  "Error handling inbound request.",
   409  			wantOutboundMsg: "Error making outbound call.",
   410  			wantFields: []zapcore.Field{
   411  				zap.Duration("latency", 0),
   412  				zap.Bool("successful", false),
   413  				zap.Skip(),
   414  				zap.Duration("timeout", ttl),
   415  				zap.Error(yErrNoDetails),
   416  				zap.String(_errorCodeLogKey, "aborted"),
   417  			},
   418  		},
   419  		{
   420  			// ie 'protobuf.NewError' return in Protobuf handler
   421  			desc:                  "yarpcerror, app error with name and code",
   422  			err:                   yErrNoDetails,
   423  			applicationErr:        true, // always true for Protobuf handler errors
   424  			wantErrLevel:          zapcore.ErrorLevel,
   425  			applicationErrDetails: appErrDetails,
   426  			applicationErrName:    "MyErrMessageName",
   427  			wantInboundMsg:        "Error handling inbound request.",
   428  			wantOutboundMsg:       "Error making outbound call.",
   429  			wantFields: []zapcore.Field{
   430  				zap.Duration("latency", 0),
   431  				zap.Bool("successful", false),
   432  				zap.Skip(), // ContextExtractor
   433  				zap.Duration("timeout", ttl),
   434  				zap.Error(yErrNoDetails),
   435  				zap.String(_errorCodeLogKey, "aborted"),
   436  				zap.String(_errorNameLogKey, "MyErrMessageName"),
   437  				zap.String(_errorDetailsLogKey, appErrDetails),
   438  			},
   439  		},
   440  		{
   441  			// ie Protobuf error detail return in Protobuf handler
   442  			desc:            "err details, app error",
   443  			err:             yErrWithDetails,
   444  			applicationErr:  true, // always true for Protobuf handler errors
   445  			wantErrLevel:    zapcore.WarnLevel,
   446  			wantInboundMsg:  "Error handling inbound request.",
   447  			wantOutboundMsg: "Error making outbound call.",
   448  			wantFields: []zapcore.Field{
   449  				zap.Duration("latency", 0),
   450  				zap.Bool("successful", false),
   451  				zap.Skip(),
   452  				zap.Duration("timeout", ttl),
   453  				zap.Error(yErrWithDetails),
   454  				zap.String(_errorCodeLogKey, "aborted"),
   455  			},
   456  		},
   457  	}
   458  
   459  	newHandler := func(t test) fakeHandler {
   460  		return fakeHandler{
   461  			err:                   t.err,
   462  			applicationErr:        t.applicationErr,
   463  			applicationErrName:    t.applicationErrName,
   464  			applicationErrDetails: t.applicationErrDetails,
   465  			applicationErrCode:    t.applicationErrCode,
   466  		}
   467  	}
   468  
   469  	newOutbound := func(t test) fakeOutbound {
   470  		return fakeOutbound{
   471  			err:                   t.err,
   472  			applicationErr:        t.applicationErr,
   473  			applicationErrName:    t.applicationErrName,
   474  			applicationErrDetails: t.applicationErrDetails,
   475  			applicationErrCode:    t.applicationErrCode,
   476  		}
   477  	}
   478  
   479  	infoLevel := zapcore.InfoLevel
   480  	warnLevel := zapcore.WarnLevel
   481  
   482  	for _, tt := range tests {
   483  		core, logs := observer.New(zapcore.DebugLevel)
   484  		mw := NewMiddleware(Config{
   485  			Logger:           zap.New(core),
   486  			Scope:            metrics.New().Scope(),
   487  			ContextExtractor: NewNopContextExtractor(),
   488  			Levels: LevelsConfig{
   489  				Default: DirectionalLevelsConfig{
   490  					Success:          &infoLevel,
   491  					ApplicationError: &warnLevel,
   492  					// Leave failure level as the default.
   493  				},
   494  			},
   495  		})
   496  
   497  		getLog := func(t *testing.T) observer.LoggedEntry {
   498  			entries := logs.TakeAll()
   499  			require.Equal(t, 1, len(entries), "Unexpected number of logs written.")
   500  			e := entries[0]
   501  			e.Entry.Time = time.Time{}
   502  			return e
   503  		}
   504  
   505  		checkErr := func(err error) {
   506  			if tt.err != nil {
   507  				assert.Error(t, err, "Expected an error from middleware.")
   508  			} else {
   509  				assert.NoError(t, err, "Unexpected error from middleware.")
   510  			}
   511  		}
   512  
   513  		t.Run(tt.desc+", unary inbound", func(t *testing.T) {
   514  			err := mw.Handle(
   515  				ctx,
   516  				req,
   517  				&transporttest.FakeResponseWriter{},
   518  				newHandler(tt),
   519  			)
   520  			checkErr(err)
   521  			logContext := append(
   522  				baseFields(),
   523  				zap.String("direction", string(_directionInbound)),
   524  				zap.String("rpcType", "Unary"),
   525  			)
   526  			logContext = append(logContext, tt.wantFields...)
   527  			expected := observer.LoggedEntry{
   528  				Entry: zapcore.Entry{
   529  					Level:   tt.wantErrLevel,
   530  					Message: tt.wantInboundMsg,
   531  				},
   532  				Context: logContext,
   533  			}
   534  			assert.Equal(t, expected, getLog(t), "Unexpected log entry written.")
   535  		})
   536  		t.Run(tt.desc+", unary outbound", func(t *testing.T) {
   537  			res, err := mw.Call(ctx, req, newOutbound(tt))
   538  			checkErr(err)
   539  			if tt.err == nil {
   540  				assert.NotNil(t, res, "Expected non-nil response if call is successful.")
   541  			}
   542  			logContext := append(
   543  				baseFields(),
   544  				zap.String("direction", string(_directionOutbound)),
   545  				zap.String("rpcType", "Unary"),
   546  			)
   547  			logContext = append(logContext, tt.wantFields...)
   548  			expected := observer.LoggedEntry{
   549  				Entry: zapcore.Entry{
   550  					Level:   tt.wantErrLevel,
   551  					Message: tt.wantOutboundMsg,
   552  				},
   553  				Context: logContext,
   554  			}
   555  			assert.Equal(t, expected, getLog(t), "Unexpected log entry written.")
   556  		})
   557  
   558  		// Application errors aren't applicable to oneway and streaming
   559  		if tt.applicationErr {
   560  			continue
   561  		}
   562  
   563  		t.Run(tt.desc+", oneway inbound", func(t *testing.T) {
   564  			err := mw.HandleOneway(ctx, req, newHandler(tt))
   565  			checkErr(err)
   566  			logContext := append(
   567  				baseFields(),
   568  				zap.String("direction", string(_directionInbound)),
   569  				zap.String("rpcType", "Oneway"),
   570  			)
   571  			logContext = append(logContext, tt.wantFields...)
   572  			expected := observer.LoggedEntry{
   573  				Entry: zapcore.Entry{
   574  					Level:   tt.wantErrLevel,
   575  					Message: tt.wantInboundMsg,
   576  				},
   577  				Context: logContext,
   578  			}
   579  			assert.Equal(t, expected, getLog(t), "Unexpected log entry written.")
   580  		})
   581  		t.Run(tt.desc+", oneway outbound", func(t *testing.T) {
   582  			ack, err := mw.CallOneway(ctx, req, newOutbound(tt))
   583  			checkErr(err)
   584  			logContext := append(
   585  				baseFields(),
   586  				zap.String("direction", string(_directionOutbound)),
   587  				zap.String("rpcType", "Oneway"),
   588  			)
   589  			logContext = append(logContext, tt.wantFields...)
   590  			if tt.err == nil {
   591  				assert.NotNil(t, ack, "Expected non-nil ack if call is successful.")
   592  			}
   593  			expected := observer.LoggedEntry{
   594  				Entry: zapcore.Entry{
   595  					Level:   tt.wantErrLevel,
   596  					Message: tt.wantOutboundMsg,
   597  				},
   598  				Context: logContext,
   599  			}
   600  			assert.Equal(t, expected, getLog(t), "Unexpected log entry written.")
   601  		})
   602  	}
   603  }
   604  
   605  func TestMiddlewareLoggingWithServerErrorConfiguration(t *testing.T) {
   606  	timeVal := time.Now()
   607  	defer stubTimeWithTimeVal(timeVal)()
   608  	ttl := time.Millisecond * 1000
   609  	ctx, cancel := context.WithDeadline(context.Background(), timeVal.Add(ttl))
   610  	defer cancel()
   611  
   612  	req := &transport.Request{
   613  		Caller:          "caller",
   614  		Service:         "service",
   615  		Transport:       "",
   616  		Encoding:        "raw",
   617  		Procedure:       "procedure",
   618  		Headers:         transport.NewHeaders().With("password", "super-secret"),
   619  		ShardKey:        "shard01",
   620  		RoutingKey:      "routing-key",
   621  		RoutingDelegate: "routing-delegate",
   622  		Body:            strings.NewReader("body"),
   623  	}
   624  
   625  	rawErr := errors.New("fail")
   626  	yErrNoDetails := yarpcerrors.Newf(yarpcerrors.CodeAborted, "fail")
   627  	yErrWithDetails := yarpcerrors.Newf(yarpcerrors.CodeAborted, "fail").WithDetails([]byte("err detail"))
   628  	yErrResourceExhausted := yarpcerrors.CodeResourceExhausted
   629  	yErrInternal := yarpcerrors.CodeInternal
   630  	appErrDetails := "an app error detail string, usually from thriftEx.Error()!"
   631  	yServerErrInternal := yarpcerrors.Newf(yarpcerrors.CodeInternal, "internal")
   632  
   633  	baseFields := func() []zapcore.Field {
   634  		return []zapcore.Field{
   635  			zap.String("source", req.Caller),
   636  			zap.String("dest", req.Service),
   637  			zap.String("transport", unknownIfEmpty(req.Transport)),
   638  			zap.String("procedure", req.Procedure),
   639  			zap.String("encoding", string(req.Encoding)),
   640  			zap.String("routingKey", req.RoutingKey),
   641  			zap.String("routingDelegate", req.RoutingDelegate),
   642  		}
   643  	}
   644  
   645  	type test struct {
   646  		desc                  string
   647  		err                   error             // downstream error
   648  		applicationErr        bool              // downstream application error
   649  		applicationErrName    string            // downstream application error name
   650  		applicationErrDetails string            // downstream application error message
   651  		applicationErrCode    *yarpcerrors.Code // downstream application error code
   652  		wantErrLevel          zapcore.Level
   653  		wantInboundMsg        string
   654  		wantOutboundMsg       string
   655  		wantFields            []zapcore.Field
   656  	}
   657  
   658  	tests := []test{
   659  		{
   660  			desc:            "success",
   661  			wantErrLevel:    zapcore.InfoLevel,
   662  			wantInboundMsg:  "Handled inbound request.",
   663  			wantOutboundMsg: "Made outbound call.",
   664  			wantFields: []zapcore.Field{
   665  				zap.Duration("latency", 0),
   666  				zap.Bool("successful", true),
   667  				zap.Skip(), // ContextExtractor
   668  				zap.Duration("timeout", ttl),
   669  			},
   670  		},
   671  		{
   672  			desc:            "downstream transport error",
   673  			err:             rawErr,
   674  			wantErrLevel:    zapcore.ErrorLevel,
   675  			wantInboundMsg:  "Error handling inbound request.",
   676  			wantOutboundMsg: "Error making outbound call.",
   677  			wantFields: []zapcore.Field{
   678  				zap.Duration("latency", 0),
   679  				zap.Bool("successful", false),
   680  				zap.Skip(),
   681  				zap.Duration("timeout", ttl),
   682  				zap.Error(rawErr),
   683  				zap.String(_errorCodeLogKey, "unknown"),
   684  			},
   685  		},
   686  		{
   687  			desc:                  "thrift application error with no name",
   688  			applicationErr:        true,
   689  			applicationErrDetails: appErrDetails,
   690  			wantErrLevel:          zapcore.WarnLevel,
   691  			wantInboundMsg:        "Error handling inbound request.",
   692  			wantOutboundMsg:       "Error making outbound call.",
   693  			wantFields: []zapcore.Field{
   694  				zap.Duration("latency", 0),
   695  				zap.Bool("successful", false),
   696  				zap.Skip(),
   697  				zap.Duration("timeout", ttl),
   698  				zap.String("error", "application_error"),
   699  				zap.String("errorDetails", appErrDetails),
   700  			},
   701  		},
   702  		{
   703  			desc:                  "thrift application error with name and code",
   704  			applicationErr:        true,
   705  			applicationErrName:    "FunkyThriftError",
   706  			applicationErrDetails: appErrDetails,
   707  			applicationErrCode:    &yErrResourceExhausted,
   708  			wantErrLevel:          zapcore.WarnLevel,
   709  			wantInboundMsg:        "Error handling inbound request.",
   710  			wantOutboundMsg:       "Error making outbound call.",
   711  			wantFields: []zapcore.Field{
   712  				zap.Duration("latency", 0),
   713  				zap.Bool("successful", false),
   714  				zap.Skip(),
   715  				zap.Duration("timeout", ttl),
   716  				zap.String("error", "application_error"),
   717  				zap.String("errorCode", "resource-exhausted"),
   718  				zap.String("errorName", "FunkyThriftError"),
   719  				zap.String("errorDetails", appErrDetails),
   720  			},
   721  		},
   722  		{
   723  			desc:                  "thrift application error with name and code (internal code error)",
   724  			applicationErr:        true,
   725  			applicationErrName:    "FunkyThriftError",
   726  			applicationErrDetails: appErrDetails,
   727  			applicationErrCode:    &yErrInternal,
   728  			wantErrLevel:          zapcore.ErrorLevel,
   729  			wantInboundMsg:        "Error handling inbound request.",
   730  			wantOutboundMsg:       "Error making outbound call.",
   731  			wantFields: []zapcore.Field{
   732  				zap.Duration("latency", 0),
   733  				zap.Bool("successful", false),
   734  				zap.Skip(),
   735  				zap.Duration("timeout", ttl),
   736  				zap.String("error", "application_error"),
   737  				zap.String("errorCode", "internal"),
   738  				zap.String("errorName", "FunkyThriftError"),
   739  				zap.String("errorDetails", appErrDetails),
   740  			},
   741  		},
   742  		{
   743  			// ie 'errors.New' return in Protobuf handler
   744  			desc:            "err and app error",
   745  			err:             rawErr,
   746  			applicationErr:  true, // always true for Protobuf handler errors
   747  			wantErrLevel:    zapcore.ErrorLevel,
   748  			wantInboundMsg:  "Error handling inbound request.",
   749  			wantOutboundMsg: "Error making outbound call.",
   750  			wantFields: []zapcore.Field{
   751  				zap.Duration("latency", 0),
   752  				zap.Bool("successful", false),
   753  				zap.Skip(),
   754  				zap.Duration("timeout", ttl),
   755  				zap.Error(rawErr),
   756  				zap.String(_errorCodeLogKey, "unknown"),
   757  			},
   758  		},
   759  		{
   760  			// ie 'yarpcerror' or 'protobuf.NewError` return in Protobuf handler
   761  			desc:            "yarpcerror, app error",
   762  			err:             yErrNoDetails,
   763  			applicationErr:  true, // always true for Protobuf handler errors
   764  			wantErrLevel:    zapcore.WarnLevel,
   765  			wantInboundMsg:  "Error handling inbound request.",
   766  			wantOutboundMsg: "Error making outbound call.",
   767  			wantFields: []zapcore.Field{
   768  				zap.Duration("latency", 0),
   769  				zap.Bool("successful", false),
   770  				zap.Skip(),
   771  				zap.Duration("timeout", ttl),
   772  				zap.Error(yErrNoDetails),
   773  				zap.String(_errorCodeLogKey, "aborted"),
   774  			},
   775  		},
   776  		{
   777  			// ie 'protobuf.NewError' return in Protobuf handler
   778  			desc:                  "yarpcerror, app error with name and code",
   779  			err:                   yErrNoDetails,
   780  			applicationErr:        true, // always true for Protobuf handler errors
   781  			wantErrLevel:          zapcore.WarnLevel,
   782  			applicationErrDetails: appErrDetails,
   783  			applicationErrName:    "MyErrMessageName",
   784  			wantInboundMsg:        "Error handling inbound request.",
   785  			wantOutboundMsg:       "Error making outbound call.",
   786  			wantFields: []zapcore.Field{
   787  				zap.Duration("latency", 0),
   788  				zap.Bool("successful", false),
   789  				zap.Skip(), // ContextExtractor
   790  				zap.Duration("timeout", ttl),
   791  				zap.Error(yErrNoDetails),
   792  				zap.String(_errorCodeLogKey, "aborted"),
   793  				zap.String(_errorNameLogKey, "MyErrMessageName"),
   794  				zap.String(_errorDetailsLogKey, appErrDetails),
   795  			},
   796  		},
   797  		{
   798  			// ie Protobuf error detail return in Protobuf handler
   799  			desc:            "err details, app error",
   800  			err:             yErrWithDetails,
   801  			applicationErr:  true, // always true for Protobuf handler errors
   802  			wantErrLevel:    zapcore.WarnLevel,
   803  			wantInboundMsg:  "Error handling inbound request.",
   804  			wantOutboundMsg: "Error making outbound call.",
   805  			wantFields: []zapcore.Field{
   806  				zap.Duration("latency", 0),
   807  				zap.Bool("successful", false),
   808  				zap.Skip(),
   809  				zap.Duration("timeout", ttl),
   810  				zap.Error(yErrWithDetails),
   811  				zap.String(_errorCodeLogKey, "aborted"),
   812  			},
   813  		},
   814  		{
   815  			// ie Protobuf error internal
   816  			desc:            "err internal, app error",
   817  			err:             yServerErrInternal,
   818  			applicationErr:  true, // always true for Protobuf handler errors
   819  			wantErrLevel:    zapcore.ErrorLevel,
   820  			wantInboundMsg:  "Error handling inbound request.",
   821  			wantOutboundMsg: "Error making outbound call.",
   822  			wantFields: []zapcore.Field{
   823  				zap.Duration("latency", 0),
   824  				zap.Bool("successful", false),
   825  				zap.Skip(),
   826  				zap.Duration("timeout", ttl),
   827  				zap.Error(yServerErrInternal),
   828  				zap.String(_errorCodeLogKey, "internal"),
   829  			},
   830  		},
   831  	}
   832  
   833  	newHandler := func(t test) fakeHandler {
   834  		return fakeHandler{
   835  			err:                   t.err,
   836  			applicationErr:        t.applicationErr,
   837  			applicationErrName:    t.applicationErrName,
   838  			applicationErrDetails: t.applicationErrDetails,
   839  			applicationErrCode:    t.applicationErrCode,
   840  		}
   841  	}
   842  
   843  	newOutbound := func(t test) fakeOutbound {
   844  		return fakeOutbound{
   845  			err:                   t.err,
   846  			applicationErr:        t.applicationErr,
   847  			applicationErrName:    t.applicationErrName,
   848  			applicationErrDetails: t.applicationErrDetails,
   849  			applicationErrCode:    t.applicationErrCode,
   850  		}
   851  	}
   852  
   853  	infoLevel := zapcore.InfoLevel
   854  	warnLevel := zapcore.WarnLevel
   855  
   856  	for _, tt := range tests {
   857  		core, logs := observer.New(zapcore.DebugLevel)
   858  		mw := NewMiddleware(Config{
   859  			Logger:           zap.New(core),
   860  			Scope:            metrics.New().Scope(),
   861  			ContextExtractor: NewNopContextExtractor(),
   862  			Levels: LevelsConfig{
   863  				Default: DirectionalLevelsConfig{
   864  					Success:     &infoLevel,
   865  					ClientError: &warnLevel,
   866  					// Leave failure level as the default.
   867  				},
   868  			},
   869  		})
   870  
   871  		getLog := func(t *testing.T) observer.LoggedEntry {
   872  			entries := logs.TakeAll()
   873  			require.Equal(t, 1, len(entries), "Unexpected number of logs written.")
   874  			e := entries[0]
   875  			e.Entry.Time = time.Time{}
   876  			return e
   877  		}
   878  
   879  		checkErr := func(err error) {
   880  			if tt.err != nil {
   881  				assert.Error(t, err, "Expected an error from middleware.")
   882  			} else {
   883  				assert.NoError(t, err, "Unexpected error from middleware.")
   884  			}
   885  		}
   886  
   887  		t.Run(tt.desc+", unary inbound", func(t *testing.T) {
   888  			err := mw.Handle(
   889  				ctx,
   890  				req,
   891  				&transporttest.FakeResponseWriter{},
   892  				newHandler(tt),
   893  			)
   894  			checkErr(err)
   895  			logContext := append(
   896  				baseFields(),
   897  				zap.String("direction", string(_directionInbound)),
   898  				zap.String("rpcType", "Unary"),
   899  			)
   900  			logContext = append(logContext, tt.wantFields...)
   901  			expected := observer.LoggedEntry{
   902  				Entry: zapcore.Entry{
   903  					Level:   tt.wantErrLevel,
   904  					Message: tt.wantInboundMsg,
   905  				},
   906  				Context: logContext,
   907  			}
   908  			assert.Equal(t, expected, getLog(t), "Unexpected log entry written.")
   909  		})
   910  		t.Run(tt.desc+", unary outbound", func(t *testing.T) {
   911  			res, err := mw.Call(ctx, req, newOutbound(tt))
   912  			checkErr(err)
   913  			if tt.err == nil {
   914  				assert.NotNil(t, res, "Expected non-nil response if call is successful.")
   915  			}
   916  			logContext := append(
   917  				baseFields(),
   918  				zap.String("direction", string(_directionOutbound)),
   919  				zap.String("rpcType", "Unary"),
   920  			)
   921  			logContext = append(logContext, tt.wantFields...)
   922  			expected := observer.LoggedEntry{
   923  				Entry: zapcore.Entry{
   924  					Level:   tt.wantErrLevel,
   925  					Message: tt.wantOutboundMsg,
   926  				},
   927  				Context: logContext,
   928  			}
   929  			assert.Equal(t, expected, getLog(t), "Unexpected log entry written.")
   930  		})
   931  
   932  		// Application errors aren't applicable to oneway and streaming
   933  		if tt.applicationErr {
   934  			continue
   935  		}
   936  
   937  		t.Run(tt.desc+", oneway inbound", func(t *testing.T) {
   938  			err := mw.HandleOneway(ctx, req, newHandler(tt))
   939  			checkErr(err)
   940  			logContext := append(
   941  				baseFields(),
   942  				zap.String("direction", string(_directionInbound)),
   943  				zap.String("rpcType", "Oneway"),
   944  			)
   945  			logContext = append(logContext, tt.wantFields...)
   946  			expected := observer.LoggedEntry{
   947  				Entry: zapcore.Entry{
   948  					Level:   tt.wantErrLevel,
   949  					Message: tt.wantInboundMsg,
   950  				},
   951  				Context: logContext,
   952  			}
   953  			assert.Equal(t, expected, getLog(t), "Unexpected log entry written.")
   954  		})
   955  		t.Run(tt.desc+", oneway outbound", func(t *testing.T) {
   956  			ack, err := mw.CallOneway(ctx, req, newOutbound(tt))
   957  			checkErr(err)
   958  			logContext := append(
   959  				baseFields(),
   960  				zap.String("direction", string(_directionOutbound)),
   961  				zap.String("rpcType", "Oneway"),
   962  			)
   963  			logContext = append(logContext, tt.wantFields...)
   964  			if tt.err == nil {
   965  				assert.NotNil(t, ack, "Expected non-nil ack if call is successful.")
   966  			}
   967  			expected := observer.LoggedEntry{
   968  				Entry: zapcore.Entry{
   969  					Level:   tt.wantErrLevel,
   970  					Message: tt.wantOutboundMsg,
   971  				},
   972  				Context: logContext,
   973  			}
   974  			assert.Equal(t, expected, getLog(t), "Unexpected log entry written.")
   975  		})
   976  	}
   977  }
   978  
   979  func TestMiddlewareStreamingSuccess(t *testing.T) {
   980  	defer stubTime()()
   981  	req := &transport.StreamRequest{
   982  		Meta: &transport.RequestMeta{
   983  			Caller:          "caller",
   984  			Service:         "service",
   985  			Transport:       "transport",
   986  			Encoding:        "raw",
   987  			Procedure:       "procedure",
   988  			Headers:         transport.NewHeaders().With("hello!", "goodbye!"),
   989  			ShardKey:        "shard-key",
   990  			RoutingKey:      "routing-key",
   991  			RoutingDelegate: "routing-delegate",
   992  		},
   993  	}
   994  
   995  	// helper function to creating logging fields for assertion
   996  	newZapFields := func(extraFields ...zapcore.Field) []zapcore.Field {
   997  		fields := []zapcore.Field{
   998  			zap.String("source", req.Meta.Caller),
   999  			zap.String("dest", req.Meta.Service),
  1000  			zap.String("transport", req.Meta.Transport),
  1001  			zap.String("procedure", req.Meta.Procedure),
  1002  			zap.String("encoding", string(req.Meta.Encoding)),
  1003  			zap.String("routingKey", req.Meta.RoutingKey),
  1004  			zap.String("routingDelegate", req.Meta.RoutingDelegate),
  1005  		}
  1006  		return append(fields, extraFields...)
  1007  	}
  1008  
  1009  	// create middleware
  1010  	core, logs := observer.New(zapcore.DebugLevel)
  1011  	infoLevel := zapcore.InfoLevel
  1012  	errorLevel := zapcore.ErrorLevel
  1013  	mw := NewMiddleware(Config{
  1014  		Logger:           zap.New(core),
  1015  		Scope:            metrics.New().Scope(),
  1016  		ContextExtractor: NewNopContextExtractor(),
  1017  		Levels: LevelsConfig{
  1018  			Default: DirectionalLevelsConfig{
  1019  				Success: &infoLevel,
  1020  				Failure: &errorLevel,
  1021  			},
  1022  		},
  1023  	})
  1024  
  1025  	// helper function to retrieve observed logs, asserting the expected number
  1026  	getLogs := func(t *testing.T, num int) []observer.LoggedEntry {
  1027  		logs := logs.TakeAll()
  1028  		require.Equal(t, num, len(logs), "expected exactly %d logs, got %v: %#v", num, len(logs), logs)
  1029  
  1030  		var entries []observer.LoggedEntry
  1031  		for _, e := range logs {
  1032  			// zero the time for easy comparisons
  1033  			e.Entry.Time = time.Time{}
  1034  			entries = append(entries, e)
  1035  		}
  1036  		return entries
  1037  	}
  1038  
  1039  	t.Run("success server", func(t *testing.T) {
  1040  		stream, err := transport.NewServerStream(&fakeStream{request: req})
  1041  		require.NoError(t, err)
  1042  
  1043  		err = mw.HandleStream(stream, &fakeHandler{
  1044  			// send and receive messages in the handler
  1045  			handleStream: func(stream *transport.ServerStream) {
  1046  				err := stream.SendMessage(context.Background(), nil /*message*/)
  1047  				require.NoError(t, err)
  1048  				_, err = stream.ReceiveMessage(context.Background())
  1049  				require.NoError(t, err)
  1050  			}})
  1051  		require.NoError(t, err)
  1052  
  1053  		logFields := func() []zapcore.Field {
  1054  			return newZapFields(
  1055  				zap.String("direction", string(_directionInbound)),
  1056  				zap.String("rpcType", "Streaming"),
  1057  				zap.Bool("successful", true),
  1058  				zap.Skip(), // context extractor
  1059  				zap.Skip(), // nil error
  1060  			)
  1061  		}
  1062  
  1063  		wantLogs := []observer.LoggedEntry{
  1064  			{
  1065  				// open stream
  1066  				Entry: zapcore.Entry{
  1067  					Message: _successStreamOpen,
  1068  				},
  1069  				Context: logFields(),
  1070  			},
  1071  			{
  1072  				// send message
  1073  				Entry: zapcore.Entry{
  1074  					Message: _successfulStreamSend,
  1075  				},
  1076  				Context: logFields(),
  1077  			},
  1078  			{
  1079  				// receive message
  1080  				Entry: zapcore.Entry{
  1081  					Message: _successfulStreamReceive,
  1082  				},
  1083  				Context: logFields(),
  1084  			},
  1085  			{
  1086  				// close stream
  1087  				Entry: zapcore.Entry{
  1088  					Message: _successStreamClose,
  1089  				},
  1090  				Context: append(logFields(), zap.Duration("duration", 0)),
  1091  			},
  1092  		}
  1093  
  1094  		// log 1 - open stream
  1095  		// log 2 - send message
  1096  		// log 3 - receive message
  1097  		// log 4 - close stream
  1098  		gotLogs := getLogs(t, 4)
  1099  		assert.Equal(t, wantLogs, gotLogs)
  1100  	})
  1101  
  1102  	t.Run("success client", func(t *testing.T) {
  1103  		stream, err := mw.CallStream(context.Background(), req, fakeOutbound{})
  1104  		require.NoError(t, err)
  1105  		err = stream.SendMessage(context.Background(), nil /* message */)
  1106  		require.NoError(t, err)
  1107  		_, err = stream.ReceiveMessage(context.Background())
  1108  		require.NoError(t, err)
  1109  		require.NoError(t, stream.Close(context.Background()))
  1110  
  1111  		fields := func() []zapcore.Field {
  1112  			return newZapFields(
  1113  				zap.String("direction", string(_directionOutbound)),
  1114  				zap.String("rpcType", "Streaming"),
  1115  				zap.Bool("successful", true),
  1116  				zap.Skip(), // context extractor
  1117  				zap.Skip(), // nil error
  1118  			)
  1119  		}
  1120  
  1121  		wantLogs := []observer.LoggedEntry{
  1122  			{
  1123  				// stream open
  1124  				Entry: zapcore.Entry{
  1125  					Message: _successStreamOpen,
  1126  				},
  1127  				Context: fields(),
  1128  			},
  1129  			{
  1130  				// stream send
  1131  				Entry: zapcore.Entry{
  1132  					Message: _successfulStreamSend,
  1133  				},
  1134  				Context: fields(),
  1135  			},
  1136  			{
  1137  				// stream receive
  1138  				Entry: zapcore.Entry{
  1139  					Message: _successfulStreamReceive,
  1140  				},
  1141  				Context: fields(),
  1142  			},
  1143  			{
  1144  				// stream close
  1145  				Entry: zapcore.Entry{
  1146  					Message: _successStreamClose,
  1147  				},
  1148  				Context: append(fields(), zap.Duration("duration", 0)),
  1149  			},
  1150  		}
  1151  
  1152  		// log 1 - open stream
  1153  		// log 2 - send message
  1154  		// log 3 - receive message
  1155  		// log 4 - close stream
  1156  		gotLogs := getLogs(t, 4)
  1157  		assert.Equal(t, wantLogs, gotLogs)
  1158  	})
  1159  }
  1160  
  1161  func TestMiddlewareStreamingLoggingErrorWithFailureConfiguration(t *testing.T) {
  1162  	defer stubTime()()
  1163  	req := &transport.StreamRequest{
  1164  		Meta: &transport.RequestMeta{
  1165  			Caller:          "caller",
  1166  			Service:         "service",
  1167  			Transport:       "transport",
  1168  			Encoding:        "raw",
  1169  			Procedure:       "procedure",
  1170  			Headers:         transport.NewHeaders().With("hello!", "goodbye!"),
  1171  			ShardKey:        "shard-key",
  1172  			RoutingKey:      "routing-key",
  1173  			RoutingDelegate: "routing-delegate",
  1174  		},
  1175  	}
  1176  
  1177  	// helper function to creating logging fields for assertion
  1178  	newZapFields := func(extraFields ...zapcore.Field) []zapcore.Field {
  1179  		fields := []zapcore.Field{
  1180  			zap.String("source", req.Meta.Caller),
  1181  			zap.String("dest", req.Meta.Service),
  1182  			zap.String("transport", req.Meta.Transport),
  1183  			zap.String("procedure", req.Meta.Procedure),
  1184  			zap.String("encoding", string(req.Meta.Encoding)),
  1185  			zap.String("routingKey", req.Meta.RoutingKey),
  1186  			zap.String("routingDelegate", req.Meta.RoutingDelegate),
  1187  		}
  1188  		return append(fields, extraFields...)
  1189  	}
  1190  
  1191  	// create middleware
  1192  	core, logs := observer.New(zapcore.DebugLevel)
  1193  	infoLevel := zapcore.InfoLevel
  1194  	errorLevel := zapcore.ErrorLevel
  1195  	mw := NewMiddleware(Config{
  1196  		Logger:           zap.New(core),
  1197  		Scope:            metrics.New().Scope(),
  1198  		ContextExtractor: NewNopContextExtractor(),
  1199  		Levels: LevelsConfig{
  1200  			Default: DirectionalLevelsConfig{
  1201  				Success: &infoLevel,
  1202  				Failure: &errorLevel,
  1203  			},
  1204  		},
  1205  	})
  1206  
  1207  	// helper function to retrieve observed logs, asserting the expected number
  1208  	getLogs := func(t *testing.T, num int) []observer.LoggedEntry {
  1209  		logs := logs.TakeAll()
  1210  		require.Equal(t, num, len(logs), "expected exactly %d logs, got %v: %#v", num, len(logs), logs)
  1211  
  1212  		var entries []observer.LoggedEntry
  1213  		for _, e := range logs {
  1214  			// zero the time for easy comparisons
  1215  			e.Entry.Time = time.Time{}
  1216  			entries = append(entries, e)
  1217  		}
  1218  		return entries
  1219  	}
  1220  
  1221  	t.Run("error handler", func(t *testing.T) {
  1222  		tests := []struct {
  1223  			name string
  1224  			err  error
  1225  		}{
  1226  			{
  1227  				name: "client fault",
  1228  				err:  yarpcerrors.InvalidArgumentErrorf("client err"),
  1229  			},
  1230  			{
  1231  				name: "server fault",
  1232  				err:  yarpcerrors.InternalErrorf("server err"),
  1233  			},
  1234  			{
  1235  				name: "unknown fault",
  1236  				err:  errors.New("unknown fault"),
  1237  			},
  1238  		}
  1239  
  1240  		for _, tt := range tests {
  1241  			t.Run(tt.name, func(t *testing.T) {
  1242  				stream, err := transport.NewServerStream(&fakeStream{request: req})
  1243  				require.NoError(t, err)
  1244  
  1245  				err = mw.HandleStream(stream, &fakeHandler{err: tt.err})
  1246  				require.Error(t, err)
  1247  
  1248  				fields := newZapFields(
  1249  					zap.String("direction", string(_directionInbound)),
  1250  					zap.String("rpcType", "Streaming"),
  1251  					zap.Bool("successful", false),
  1252  					zap.Skip(), // context extractor
  1253  					zap.Error(tt.err),
  1254  					zap.Duration("duration", 0),
  1255  				)
  1256  
  1257  				wantLog := observer.LoggedEntry{
  1258  					Entry: zapcore.Entry{
  1259  						Message: _errorStreamClose,
  1260  						Level:   zapcore.ErrorLevel,
  1261  					},
  1262  					Context: fields,
  1263  				}
  1264  
  1265  				// The stream handler is only executed after a stream successfully connects
  1266  				// with a client. Therefore the first streaming log will always be
  1267  				// successful (tested in the previous subtest). We only care about the
  1268  				// stream termination so we retrieve the last log.
  1269  				//
  1270  				// log 1 - open stream
  1271  				// log 2 - close stream
  1272  				gotLog := getLogs(t, 2)[1]
  1273  				assert.Equal(t, wantLog, gotLog)
  1274  			})
  1275  		}
  1276  	})
  1277  
  1278  	t.Run("error server - send and recv", func(t *testing.T) {
  1279  		sendErr := errors.New("send err")
  1280  		receiveErr := errors.New("receive err")
  1281  
  1282  		stream, err := transport.NewServerStream(&fakeStream{
  1283  			request:    req,
  1284  			sendErr:    sendErr,
  1285  			receiveErr: receiveErr,
  1286  		})
  1287  		require.NoError(t, err)
  1288  
  1289  		err = mw.HandleStream(stream, &fakeHandler{
  1290  			// send and receive messages in the handler
  1291  			handleStream: func(stream *transport.ServerStream) {
  1292  				err := stream.SendMessage(context.Background(), nil /*message*/)
  1293  				require.Error(t, err)
  1294  				_, err = stream.ReceiveMessage(context.Background())
  1295  				require.Error(t, err)
  1296  			}})
  1297  		require.NoError(t, err)
  1298  
  1299  		fields := func() []zapcore.Field {
  1300  			return newZapFields(
  1301  				zap.String("direction", string(_directionInbound)),
  1302  				zap.String("rpcType", "Streaming"),
  1303  				zap.Bool("successful", false),
  1304  				zap.Skip(), // context extractor
  1305  			)
  1306  		}
  1307  
  1308  		wantLogs := []observer.LoggedEntry{
  1309  			{
  1310  				// send message
  1311  				Entry: zapcore.Entry{
  1312  					Message: _errorStreamSend,
  1313  					Level:   zapcore.ErrorLevel,
  1314  				},
  1315  				Context: append(fields(), zap.Error(sendErr)),
  1316  			},
  1317  			{
  1318  				// receive message
  1319  				Entry: zapcore.Entry{
  1320  					Message: _errorStreamReceive,
  1321  					Level:   zapcore.ErrorLevel,
  1322  				},
  1323  				Context: append(fields(), zap.Error(receiveErr)),
  1324  			},
  1325  		}
  1326  
  1327  		// We are only interested in the send and receive logs.
  1328  		// log 1 - open stream
  1329  		// log 2 - send message
  1330  		// log 3 - receive message
  1331  		// log 4 - close stream
  1332  		gotLogs := getLogs(t, 4)[1:3]
  1333  		assert.Equal(t, wantLogs, gotLogs)
  1334  	})
  1335  
  1336  	t.Run("error client handshake", func(t *testing.T) {
  1337  		clientErr := errors.New("client err")
  1338  		_, err := mw.CallStream(context.Background(), req, fakeOutbound{err: clientErr})
  1339  		require.Error(t, err)
  1340  
  1341  		fields := func() []zapcore.Field {
  1342  			return newZapFields(
  1343  				zap.String("direction", string(_directionOutbound)),
  1344  				zap.String("rpcType", "Streaming"),
  1345  				zap.Bool("successful", false),
  1346  				zap.Skip(), // context extractor
  1347  				zap.Error(clientErr),
  1348  			)
  1349  		}
  1350  
  1351  		wantLogs := []observer.LoggedEntry{
  1352  			{
  1353  				// stream open
  1354  				Entry: zapcore.Entry{
  1355  					Message: _errorStreamOpen,
  1356  					Level:   zapcore.ErrorLevel,
  1357  				},
  1358  				Context: fields(),
  1359  			},
  1360  		}
  1361  
  1362  		// log 1 - open stream
  1363  		gotLogs := getLogs(t, 1)
  1364  		assert.Equal(t, wantLogs, gotLogs)
  1365  	})
  1366  
  1367  	t.Run("error client - send recv close", func(t *testing.T) {
  1368  		sendErr := errors.New("send err")
  1369  		receiveErr := errors.New("receive err")
  1370  		closeErr := errors.New("close err")
  1371  
  1372  		stream, err := mw.CallStream(context.Background(), req, fakeOutbound{
  1373  			stream: fakeStream{
  1374  				sendErr:    sendErr,
  1375  				receiveErr: receiveErr,
  1376  				closeErr:   closeErr,
  1377  			}})
  1378  		require.NoError(t, err)
  1379  
  1380  		err = stream.SendMessage(context.Background(), nil /* message */)
  1381  		require.Error(t, err)
  1382  		_, err = stream.ReceiveMessage(context.Background())
  1383  		require.Error(t, err)
  1384  		err = stream.Close(context.Background())
  1385  		require.Error(t, err)
  1386  
  1387  		fields := func() []zapcore.Field {
  1388  			return newZapFields(
  1389  				zap.String("direction", string(_directionOutbound)),
  1390  				zap.String("rpcType", "Streaming"),
  1391  				zap.Bool("successful", false),
  1392  				zap.Skip(), // context extractor
  1393  			)
  1394  		}
  1395  
  1396  		wantLogs := []observer.LoggedEntry{
  1397  			{
  1398  				// send message
  1399  				Entry: zapcore.Entry{
  1400  					Message: _errorStreamSend,
  1401  					Level:   zapcore.ErrorLevel,
  1402  				},
  1403  				Context: append(fields(), zap.Error(sendErr)),
  1404  			},
  1405  			{
  1406  				// receive message
  1407  				Entry: zapcore.Entry{
  1408  					Message: _errorStreamReceive,
  1409  					Level:   zapcore.ErrorLevel,
  1410  				},
  1411  				Context: append(fields(), zap.Error(receiveErr)),
  1412  			},
  1413  			{
  1414  				// close stream
  1415  				Entry: zapcore.Entry{
  1416  					Message: _errorStreamClose,
  1417  					Level:   zapcore.ErrorLevel,
  1418  				},
  1419  				Context: append(fields(), zap.Error(closeErr), zap.Duration("duration", 0)),
  1420  			},
  1421  		}
  1422  
  1423  		// We are only interested in the send, receive and stream close logs
  1424  		// log 1 - open stream
  1425  		// log 2 - send message
  1426  		// log 3 - receive message
  1427  		// log 4 - close stream
  1428  		gotLogs := getLogs(t, 4)[1:]
  1429  		assert.Equal(t, wantLogs, gotLogs)
  1430  	})
  1431  
  1432  	t.Run("EOF is a success with an error", func(t *testing.T) {
  1433  		ctx, cancel := context.WithCancel(context.Background())
  1434  		defer cancel()
  1435  
  1436  		clientStream, serverStream, finish, err := transporttest.MessagePipe(ctx, req)
  1437  		require.NoError(t, err)
  1438  
  1439  		var wg sync.WaitGroup
  1440  		wg.Add(1)
  1441  		go func() {
  1442  			finish(mw.HandleStream(serverStream, &fakeHandler{
  1443  				// send and receive messages in the handler
  1444  				handleStream: func(stream *transport.ServerStream) {
  1445  					// echo loop
  1446  					for {
  1447  						msg, err := stream.ReceiveMessage(ctx)
  1448  						if err == io.EOF {
  1449  							return
  1450  						}
  1451  						err = stream.SendMessage(ctx, msg)
  1452  						if err == io.EOF {
  1453  							return
  1454  						}
  1455  					}
  1456  				},
  1457  			}))
  1458  			wg.Done()
  1459  		}()
  1460  
  1461  		{
  1462  			err := clientStream.SendMessage(ctx, nil)
  1463  			require.NoError(t, err)
  1464  		}
  1465  
  1466  		{
  1467  			msg, err := clientStream.ReceiveMessage(ctx)
  1468  			require.NoError(t, err)
  1469  			assert.Nil(t, msg)
  1470  		}
  1471  
  1472  		require.NoError(t, clientStream.Close(ctx))
  1473  
  1474  		wg.Wait()
  1475  
  1476  		logFields := func(err error) []zapcore.Field {
  1477  			return newZapFields(
  1478  				zap.String("direction", string(_directionInbound)),
  1479  				zap.String("rpcType", "Streaming"),
  1480  				zap.Bool("successful", true),
  1481  				zap.Skip(), // context extractor
  1482  				zap.Error(err),
  1483  			)
  1484  		}
  1485  
  1486  		wantLogs := []observer.LoggedEntry{
  1487  			{
  1488  				// open stream
  1489  				Entry: zapcore.Entry{
  1490  					Message: _successStreamOpen,
  1491  				},
  1492  				Context: logFields(nil),
  1493  			},
  1494  			{
  1495  				// receive message
  1496  				Entry: zapcore.Entry{
  1497  					Message: _successfulStreamReceive,
  1498  				},
  1499  				Context: logFields(nil),
  1500  			},
  1501  			{
  1502  				// send message
  1503  				Entry: zapcore.Entry{
  1504  					Message: _successfulStreamSend,
  1505  				},
  1506  				Context: logFields(nil),
  1507  			},
  1508  			{
  1509  				// receive message (EOF)
  1510  				Entry: zapcore.Entry{
  1511  					Message: _successfulStreamReceive,
  1512  				},
  1513  				Context: logFields(io.EOF),
  1514  			},
  1515  			{
  1516  				// close stream
  1517  				Entry: zapcore.Entry{
  1518  					Message: _successStreamClose,
  1519  				},
  1520  				Context: append(logFields(nil), zap.Duration("duration", 0)),
  1521  			},
  1522  		}
  1523  
  1524  		// log 1 - open stream
  1525  		// log 2 - receive message
  1526  		// log 3 - send message
  1527  		// log 4 - receive message
  1528  		// log 5 - close stream
  1529  		gotLogs := getLogs(t, 5)
  1530  		assert.Equal(t, wantLogs, gotLogs)
  1531  	})
  1532  }
  1533  
  1534  func TestMiddlewareStreamingLoggingErrorWithServerClientConfiguration(t *testing.T) {
  1535  	defer stubTime()()
  1536  	req := &transport.StreamRequest{
  1537  		Meta: &transport.RequestMeta{
  1538  			Caller:          "caller",
  1539  			Service:         "service",
  1540  			Transport:       "transport",
  1541  			Encoding:        "raw",
  1542  			Procedure:       "procedure",
  1543  			Headers:         transport.NewHeaders().With("hello!", "goodbye!"),
  1544  			ShardKey:        "shard-key",
  1545  			RoutingKey:      "routing-key",
  1546  			RoutingDelegate: "routing-delegate",
  1547  		},
  1548  	}
  1549  
  1550  	// helper function to creating logging fields for assertion
  1551  	newZapFields := func(extraFields ...zapcore.Field) []zapcore.Field {
  1552  		fields := []zapcore.Field{
  1553  			zap.String("source", req.Meta.Caller),
  1554  			zap.String("dest", req.Meta.Service),
  1555  			zap.String("transport", req.Meta.Transport),
  1556  			zap.String("procedure", req.Meta.Procedure),
  1557  			zap.String("encoding", string(req.Meta.Encoding)),
  1558  			zap.String("routingKey", req.Meta.RoutingKey),
  1559  			zap.String("routingDelegate", req.Meta.RoutingDelegate),
  1560  		}
  1561  		return append(fields, extraFields...)
  1562  	}
  1563  
  1564  	// create middleware
  1565  	core, logs := observer.New(zapcore.DebugLevel)
  1566  	infoLevel := zapcore.InfoLevel
  1567  	warnLevel := zapcore.WarnLevel
  1568  	mw := NewMiddleware(Config{
  1569  		Logger:           zap.New(core),
  1570  		Scope:            metrics.New().Scope(),
  1571  		ContextExtractor: NewNopContextExtractor(),
  1572  		Levels: LevelsConfig{
  1573  			Default: DirectionalLevelsConfig{
  1574  				Success:     &infoLevel,
  1575  				ClientError: &warnLevel,
  1576  			},
  1577  		},
  1578  	})
  1579  
  1580  	// helper function to retrieve observed logs, asserting the expected number
  1581  	getLogs := func(t *testing.T, num int) []observer.LoggedEntry {
  1582  		logs := logs.TakeAll()
  1583  		require.Equal(t, num, len(logs), "expected exactly %d logs, got %v: %#v", num, len(logs), logs)
  1584  
  1585  		var entries []observer.LoggedEntry
  1586  		for _, e := range logs {
  1587  			// zero the time for easy comparisons
  1588  			e.Entry.Time = time.Time{}
  1589  			entries = append(entries, e)
  1590  		}
  1591  		return entries
  1592  	}
  1593  
  1594  	t.Run("error handler", func(t *testing.T) {
  1595  		tests := []struct {
  1596  			name      string
  1597  			err       error
  1598  			wantLevel zapcore.Level
  1599  		}{
  1600  			{
  1601  				name:      "client fault",
  1602  				err:       yarpcerrors.InvalidArgumentErrorf("client err"),
  1603  				wantLevel: zapcore.WarnLevel,
  1604  			},
  1605  			{
  1606  				name:      "server fault",
  1607  				err:       yarpcerrors.InternalErrorf("server err"),
  1608  				wantLevel: zapcore.ErrorLevel,
  1609  			},
  1610  			{
  1611  				name:      "unknown fault",
  1612  				err:       errors.New("unknown fault"),
  1613  				wantLevel: zapcore.ErrorLevel,
  1614  			},
  1615  		}
  1616  
  1617  		for _, tt := range tests {
  1618  			t.Run(tt.name, func(t *testing.T) {
  1619  				stream, err := transport.NewServerStream(&fakeStream{request: req})
  1620  				require.NoError(t, err)
  1621  
  1622  				err = mw.HandleStream(stream, &fakeHandler{err: tt.err})
  1623  				require.Error(t, err)
  1624  
  1625  				fields := newZapFields(
  1626  					zap.String("direction", string(_directionInbound)),
  1627  					zap.String("rpcType", "Streaming"),
  1628  					zap.Bool("successful", false),
  1629  					zap.Skip(), // context extractor
  1630  					zap.Error(tt.err),
  1631  					zap.Duration("duration", 0),
  1632  				)
  1633  
  1634  				wantLog := observer.LoggedEntry{
  1635  					Entry: zapcore.Entry{
  1636  						Message: _errorStreamClose,
  1637  						Level:   tt.wantLevel,
  1638  					},
  1639  					Context: fields,
  1640  				}
  1641  
  1642  				// The stream handler is only executed after a stream successfully connects
  1643  				// with a client. Therefore the first streaming log will always be
  1644  				// successful (tested in the previous subtest). We only care about the
  1645  				// stream termination so we retrieve the last log.
  1646  				//
  1647  				// log 1 - open stream
  1648  				// log 2 - close stream
  1649  				gotLog := getLogs(t, 2)[1]
  1650  				assert.Equal(t, wantLog, gotLog)
  1651  			})
  1652  		}
  1653  	})
  1654  }
  1655  
  1656  func TestMiddlewareMetrics(t *testing.T) {
  1657  	defer stubTime()()
  1658  	req := &transport.Request{
  1659  		Caller:    "caller",
  1660  		Service:   "service",
  1661  		Transport: "",
  1662  		Encoding:  "raw",
  1663  		Procedure: "procedure",
  1664  		Body:      strings.NewReader("body"),
  1665  	}
  1666  
  1667  	yErrAlreadyExists := yarpcerrors.CodeAlreadyExists
  1668  	yErrCodeUnknown := yarpcerrors.CodeUnknown
  1669  
  1670  	type failureTags struct {
  1671  		errorTag     string
  1672  		errorNameTag string
  1673  	}
  1674  
  1675  	type test struct {
  1676  		desc               string
  1677  		err                error             // downstream error
  1678  		applicationErr     bool              // downstream application error
  1679  		applicationErrName string            // downstream application error name
  1680  		applicationErrCode *yarpcerrors.Code // downstream application error code
  1681  		wantCalls          int
  1682  		wantSuccesses      int
  1683  		wantCallerFailures map[failureTags]int
  1684  		wantServerFailures map[failureTags]int
  1685  	}
  1686  
  1687  	tests := []test{
  1688  		{
  1689  			desc:          "no downstream errors",
  1690  			wantCalls:     1,
  1691  			wantSuccesses: 1,
  1692  		},
  1693  		{
  1694  			desc:          "invalid argument error",
  1695  			err:           yarpcerrors.Newf(yarpcerrors.CodeInvalidArgument, "test"),
  1696  			wantCalls:     1,
  1697  			wantSuccesses: 0,
  1698  			wantCallerFailures: map[failureTags]int{
  1699  				{
  1700  					errorTag:     yarpcerrors.CodeInvalidArgument.String(),
  1701  					errorNameTag: _notSet,
  1702  				}: 1,
  1703  			},
  1704  		},
  1705  		{
  1706  			desc:          "internal error",
  1707  			err:           yarpcerrors.Newf(yarpcerrors.CodeInternal, "test"),
  1708  			wantCalls:     1,
  1709  			wantSuccesses: 0,
  1710  			wantServerFailures: map[failureTags]int{
  1711  				{
  1712  					errorTag:     yarpcerrors.CodeInternal.String(),
  1713  					errorNameTag: _notSet,
  1714  				}: 1,
  1715  			},
  1716  		},
  1717  		{
  1718  			desc:          "unknown (unwrapped) error",
  1719  			err:           errors.New("test"),
  1720  			wantCalls:     1,
  1721  			wantSuccesses: 0,
  1722  			wantServerFailures: map[failureTags]int{
  1723  				{
  1724  					errorTag:     "unknown_internal_yarpc",
  1725  					errorNameTag: _notSet,
  1726  				}: 1,
  1727  			},
  1728  		},
  1729  		{
  1730  			desc:          "custom error code error",
  1731  			err:           yarpcerrors.Newf(yarpcerrors.Code(1000), "test"),
  1732  			wantCalls:     1,
  1733  			wantSuccesses: 0,
  1734  			wantServerFailures: map[failureTags]int{
  1735  				{
  1736  					errorTag:     "1000",
  1737  					errorNameTag: _notSet,
  1738  				}: 1,
  1739  			},
  1740  		},
  1741  		{
  1742  			desc:               "application error name with no code",
  1743  			wantCalls:          1,
  1744  			wantSuccesses:      0,
  1745  			applicationErr:     true,
  1746  			applicationErrName: "SomeError",
  1747  			wantCallerFailures: map[failureTags]int{
  1748  				{
  1749  					errorTag:     "application_error",
  1750  					errorNameTag: "SomeError",
  1751  				}: 1,
  1752  			},
  1753  		},
  1754  		{
  1755  			desc:               "application error name with YARPC code - caller failure",
  1756  			wantCalls:          1,
  1757  			wantSuccesses:      0,
  1758  			applicationErr:     true,
  1759  			applicationErrName: "SomeError",
  1760  			applicationErrCode: &yErrAlreadyExists,
  1761  			wantCallerFailures: map[failureTags]int{
  1762  				{
  1763  					errorTag:     "already-exists",
  1764  					errorNameTag: "SomeError",
  1765  				}: 1,
  1766  			},
  1767  		},
  1768  		{
  1769  			desc:               "application error name with YARPC code - server failure",
  1770  			wantCalls:          1,
  1771  			wantSuccesses:      0,
  1772  			applicationErr:     true,
  1773  			applicationErrName: "InternalServerPain",
  1774  			applicationErrCode: &yErrCodeUnknown,
  1775  			wantServerFailures: map[failureTags]int{
  1776  				{
  1777  					errorTag:     "unknown",
  1778  					errorNameTag: "InternalServerPain",
  1779  				}: 1,
  1780  			},
  1781  		},
  1782  		{
  1783  			desc:               "application error with YARPC code and empty name",
  1784  			wantCalls:          1,
  1785  			wantSuccesses:      0,
  1786  			applicationErr:     true,
  1787  			applicationErrName: "",
  1788  			applicationErrCode: &yErrAlreadyExists,
  1789  			wantCallerFailures: map[failureTags]int{
  1790  				{
  1791  					errorTag:     "already-exists",
  1792  					errorNameTag: _notSet,
  1793  				}: 1,
  1794  			},
  1795  		},
  1796  	}
  1797  
  1798  	newHandler := func(t test) fakeHandler {
  1799  		return fakeHandler{
  1800  			err:                t.err,
  1801  			applicationErr:     t.applicationErr,
  1802  			applicationErrName: t.applicationErrName,
  1803  			applicationErrCode: t.applicationErrCode,
  1804  		}
  1805  	}
  1806  
  1807  	newOutbound := func(t test) fakeOutbound {
  1808  		return fakeOutbound{
  1809  			err:                t.err,
  1810  			applicationErr:     t.applicationErr,
  1811  			applicationErrName: t.applicationErrName,
  1812  			applicationErrCode: t.applicationErrCode,
  1813  		}
  1814  	}
  1815  
  1816  	for _, tt := range tests {
  1817  		validate := func(mw *Middleware, direction string, rpcType transport.Type) {
  1818  			key, free := getKey(req, direction, rpcType)
  1819  			edge := mw.graph.getEdge(key)
  1820  			free()
  1821  			assert.EqualValues(t, tt.wantCalls, edge.calls.Load(), "expected calls mismatch")
  1822  			assert.EqualValues(t, tt.wantSuccesses, edge.successes.Load(), "expected successes mismatch")
  1823  			assert.EqualValues(t, 0, edge.panics.Load(), "expected panics mismatch")
  1824  			for failureTags, num := range tt.wantCallerFailures {
  1825  				assert.EqualValues(t, num, edge.callerFailures.MustGet(
  1826  					_error, failureTags.errorTag,
  1827  					_errorNameMetricsKey, failureTags.errorNameTag,
  1828  				).Load(), "expected caller failures mismatch")
  1829  			}
  1830  			for failureTags, num := range tt.wantServerFailures {
  1831  				assert.EqualValues(t, num, edge.serverFailures.MustGet(
  1832  					_error, failureTags.errorTag,
  1833  					_errorNameMetricsKey, failureTags.errorNameTag,
  1834  				).Load(), "expected server failures mismatch")
  1835  			}
  1836  		}
  1837  		t.Run(tt.desc+", unary inbound", func(t *testing.T) {
  1838  			mw := NewMiddleware(Config{
  1839  				Logger:           zap.NewNop(),
  1840  				Scope:            metrics.New().Scope(),
  1841  				ContextExtractor: NewNopContextExtractor(),
  1842  			})
  1843  			mw.Handle(
  1844  				context.Background(),
  1845  				req,
  1846  				&transporttest.FakeResponseWriter{},
  1847  				newHandler(tt),
  1848  			)
  1849  			validate(mw, string(_directionInbound), transport.Unary)
  1850  		})
  1851  		t.Run(tt.desc+", unary outbound", func(t *testing.T) {
  1852  			mw := NewMiddleware(Config{
  1853  				Logger:           zap.NewNop(),
  1854  				Scope:            metrics.New().Scope(),
  1855  				ContextExtractor: NewNopContextExtractor(),
  1856  			})
  1857  			mw.Call(context.Background(), req, newOutbound(tt))
  1858  			validate(mw, string(_directionOutbound), transport.Unary)
  1859  		})
  1860  	}
  1861  }
  1862  
  1863  // getKey gets the "key" that we will use to get an edge in the graph.  We use
  1864  // a separate function to recreate the logic because extracting it out in the
  1865  // main code could have performance implications.
  1866  func getKey(req *transport.Request, direction string, rpcType transport.Type) (key []byte, free func()) {
  1867  	d := digester.New()
  1868  	d.Add(req.Caller)
  1869  	d.Add(req.Service)
  1870  	d.Add(req.Transport)
  1871  	d.Add(string(req.Encoding))
  1872  	d.Add(req.Procedure)
  1873  	d.Add(req.RoutingKey)
  1874  	d.Add(req.RoutingDelegate)
  1875  	d.Add(direction)
  1876  	d.Add(rpcType.String())
  1877  	return d.Digest(), d.Free
  1878  }
  1879  
  1880  func TestUnaryInboundApplicationErrors(t *testing.T) {
  1881  	timeVal := time.Now()
  1882  	defer stubTimeWithTimeVal(timeVal)()
  1883  	ttl := time.Millisecond * 1000
  1884  	ctx, cancel := context.WithDeadline(context.Background(), timeVal.Add(ttl))
  1885  	defer cancel()
  1886  
  1887  	yErrAlreadyExists := yarpcerrors.CodeAlreadyExists
  1888  
  1889  	req := &transport.Request{
  1890  		Caller:          "caller",
  1891  		Service:         "service",
  1892  		Transport:       "",
  1893  		Encoding:        "raw",
  1894  		Procedure:       "procedure",
  1895  		ShardKey:        "shard01",
  1896  		RoutingKey:      "routing-key",
  1897  		RoutingDelegate: "routing-delegate",
  1898  		Body:            strings.NewReader("body"),
  1899  	}
  1900  
  1901  	expectedFields := []zapcore.Field{
  1902  		zap.String("source", req.Caller),
  1903  		zap.String("dest", req.Service),
  1904  		zap.String("transport", "unknown"),
  1905  		zap.String("procedure", req.Procedure),
  1906  		zap.String("encoding", string(req.Encoding)),
  1907  		zap.String("routingKey", req.RoutingKey),
  1908  		zap.String("routingDelegate", req.RoutingDelegate),
  1909  		zap.String("direction", string(_directionInbound)),
  1910  		zap.String("rpcType", "Unary"),
  1911  		zap.Duration("latency", 0),
  1912  		zap.Bool("successful", false),
  1913  		zap.Skip(),
  1914  		zap.Duration("timeout", ttl),
  1915  		zap.String("error", "application_error"),
  1916  		zap.String("errorCode", "already-exists"),
  1917  		zap.String("errorName", "SomeFakeError"),
  1918  	}
  1919  
  1920  	core, logs := observer.New(zap.DebugLevel)
  1921  	mw := NewMiddleware(Config{
  1922  		Logger:           zap.New(core),
  1923  		Scope:            metrics.New().Scope(),
  1924  		ContextExtractor: NewNopContextExtractor(),
  1925  	})
  1926  
  1927  	assert.NoError(t, mw.Handle(
  1928  		ctx,
  1929  		req,
  1930  		&transporttest.FakeResponseWriter{},
  1931  		fakeHandler{
  1932  			err:                nil,
  1933  			applicationErr:     true,
  1934  			applicationErrName: "SomeFakeError",
  1935  			applicationErrCode: &yErrAlreadyExists,
  1936  		},
  1937  	), "Unexpected transport error.")
  1938  
  1939  	expected := observer.LoggedEntry{
  1940  		Entry: zapcore.Entry{
  1941  			Level:   zapcore.ErrorLevel,
  1942  			Message: "Error handling inbound request.",
  1943  		},
  1944  		Context: expectedFields,
  1945  	}
  1946  	entries := logs.TakeAll()
  1947  	require.Equal(t, 1, len(entries), "Unexpected number of log entries written.")
  1948  	entry := entries[0]
  1949  	entry.Time = time.Time{}
  1950  	assert.Equal(t, expected, entry, "Unexpected log entry written.")
  1951  }
  1952  
  1953  func TestMiddlewareSuccessSnapshot(t *testing.T) {
  1954  	timeVal := time.Now()
  1955  	defer stubTimeWithTimeVal(timeVal)()
  1956  	ttlMs := int64(1000)
  1957  	ttl := time.Millisecond * time.Duration(1000)
  1958  	root := metrics.New()
  1959  	meter := root.Scope()
  1960  	mw := NewMiddleware(Config{
  1961  		Logger:           zap.NewNop(),
  1962  		Scope:            meter,
  1963  		ContextExtractor: NewNopContextExtractor(),
  1964  	})
  1965  
  1966  	buf := bufferpool.Get()
  1967  	defer bufferpool.Put(buf)
  1968  
  1969  	buf.Write([]byte("body"))
  1970  
  1971  	ctx, cancel := context.WithDeadline(context.Background(), timeVal.Add(ttl))
  1972  	defer cancel()
  1973  	err := mw.Handle(
  1974  		ctx,
  1975  		&transport.Request{
  1976  			Caller:          "caller",
  1977  			Service:         "service",
  1978  			Transport:       "",
  1979  			Encoding:        "raw",
  1980  			Procedure:       "procedure",
  1981  			ShardKey:        "sk",
  1982  			RoutingKey:      "rk",
  1983  			RoutingDelegate: "rd",
  1984  			Body:            buf,
  1985  			// overwrite with fixed value of 256MB to test large bucket.
  1986  			// large buffer body causes CI to timeout.
  1987  			BodySize: 1024 * 1024 * 256,
  1988  		},
  1989  		&transporttest.FakeResponseWriter{},
  1990  		fakeHandler{responseData: make([]byte, 1024*1024*16)},
  1991  	)
  1992  	assert.NoError(t, err, "Unexpected transport error.")
  1993  
  1994  	snap := root.Snapshot()
  1995  	tags := metrics.Tags{
  1996  		"dest":             "service",
  1997  		"direction":        "inbound",
  1998  		"transport":        "unknown",
  1999  		"encoding":         "raw",
  2000  		"procedure":        "procedure",
  2001  		"routing_delegate": "rd",
  2002  		"routing_key":      "rk",
  2003  		"rpc_type":         transport.Unary.String(),
  2004  		"source":           "caller",
  2005  	}
  2006  	want := &metrics.RootSnapshot{
  2007  		Counters: []metrics.Snapshot{
  2008  			{Name: "calls", Tags: tags, Value: 1},
  2009  			{Name: "panics", Tags: tags, Value: 0},
  2010  			{Name: "successes", Tags: tags, Value: 1},
  2011  		},
  2012  		Histograms: []metrics.HistogramSnapshot{
  2013  			{
  2014  				Name: "caller_failure_latency_ms",
  2015  				Tags: tags,
  2016  				Unit: time.Millisecond,
  2017  			},
  2018  			{
  2019  				Name:   "request_payload_size_bytes",
  2020  				Tags:   tags,
  2021  				Unit:   time.Millisecond,
  2022  				Values: []int64{268435456}, // 256MB
  2023  			},
  2024  			{
  2025  				Name:   "response_payload_size_bytes",
  2026  				Tags:   tags,
  2027  				Unit:   time.Millisecond,
  2028  				Values: []int64{16777216}, // 16MB
  2029  			},
  2030  			{
  2031  				Name: "server_failure_latency_ms",
  2032  				Tags: tags,
  2033  				Unit: time.Millisecond,
  2034  			},
  2035  			{
  2036  				Name:   "success_latency_ms",
  2037  				Tags:   tags,
  2038  				Unit:   time.Millisecond,
  2039  				Values: []int64{1},
  2040  			},
  2041  			{
  2042  				Name: "timeout_ttl_ms",
  2043  				Tags: tags,
  2044  				Unit: time.Millisecond,
  2045  			},
  2046  			{
  2047  				Name:   "ttl_ms",
  2048  				Tags:   tags,
  2049  				Unit:   time.Millisecond,
  2050  				Values: []int64{ttlMs},
  2051  			},
  2052  		},
  2053  	}
  2054  	assert.Equal(t, want, snap, "Unexpected snapshot of metrics.")
  2055  }
  2056  
  2057  func TestMiddlewareSuccessSnapshotWithTagsFiltered(t *testing.T) {
  2058  	timeVal := time.Now()
  2059  	defer stubTimeWithTimeVal(timeVal)()
  2060  	ttlMs := int64(1000)
  2061  	ttl := time.Millisecond * time.Duration(1000)
  2062  	root := metrics.New()
  2063  	meter := root.Scope()
  2064  	mw := NewMiddleware(Config{
  2065  		Logger:           zap.NewNop(),
  2066  		Scope:            meter,
  2067  		ContextExtractor: NewNopContextExtractor(),
  2068  		MetricTagsBlocklist: []string{
  2069  			"routing_delegate",
  2070  		},
  2071  	})
  2072  
  2073  	buf := bufferpool.Get()
  2074  	defer bufferpool.Put(buf)
  2075  
  2076  	buf.Write([]byte("body"))
  2077  
  2078  	ctx, cancel := context.WithDeadline(context.Background(), timeVal.Add(ttl))
  2079  	defer cancel()
  2080  	err := mw.Handle(
  2081  		ctx,
  2082  		&transport.Request{
  2083  			Caller:          "caller",
  2084  			Service:         "service",
  2085  			Transport:       "",
  2086  			Encoding:        "raw",
  2087  			Procedure:       "procedure",
  2088  			ShardKey:        "sk",
  2089  			RoutingKey:      "rk",
  2090  			RoutingDelegate: "rd",
  2091  			Body:            buf,
  2092  			BodySize:        buf.Len(),
  2093  		},
  2094  		&transporttest.FakeResponseWriter{},
  2095  		fakeHandler{responseData: []byte("test response")},
  2096  	)
  2097  	assert.NoError(t, err, "Unexpected transport error.")
  2098  
  2099  	snap := root.Snapshot()
  2100  	tags := metrics.Tags{
  2101  		"dest":             "service",
  2102  		"direction":        "inbound",
  2103  		"transport":        "unknown",
  2104  		"encoding":         "raw",
  2105  		"procedure":        "procedure",
  2106  		"routing_key":      "rk",
  2107  		"routing_delegate": "__dropped__",
  2108  		"rpc_type":         transport.Unary.String(),
  2109  		"source":           "caller",
  2110  	}
  2111  	want := &metrics.RootSnapshot{
  2112  		Counters: []metrics.Snapshot{
  2113  			{Name: "calls", Tags: tags, Value: 1},
  2114  			{Name: "panics", Tags: tags, Value: 0},
  2115  			{Name: "successes", Tags: tags, Value: 1},
  2116  		},
  2117  		Histograms: []metrics.HistogramSnapshot{
  2118  			{
  2119  				Name: "caller_failure_latency_ms",
  2120  				Tags: tags,
  2121  				Unit: time.Millisecond,
  2122  			},
  2123  			{
  2124  				Name:   "request_payload_size_bytes",
  2125  				Tags:   tags,
  2126  				Unit:   time.Millisecond,
  2127  				Values: []int64{4},
  2128  			},
  2129  			{
  2130  				Name:   "response_payload_size_bytes",
  2131  				Tags:   tags,
  2132  				Unit:   time.Millisecond,
  2133  				Values: []int64{16},
  2134  			},
  2135  			{
  2136  				Name: "server_failure_latency_ms",
  2137  				Tags: tags,
  2138  				Unit: time.Millisecond,
  2139  			},
  2140  			{
  2141  				Name:   "success_latency_ms",
  2142  				Tags:   tags,
  2143  				Unit:   time.Millisecond,
  2144  				Values: []int64{1},
  2145  			},
  2146  			{
  2147  				Name: "timeout_ttl_ms",
  2148  				Tags: tags,
  2149  				Unit: time.Millisecond,
  2150  			},
  2151  			{
  2152  				Name:   "ttl_ms",
  2153  				Tags:   tags,
  2154  				Unit:   time.Millisecond,
  2155  				Values: []int64{ttlMs},
  2156  			},
  2157  		},
  2158  	}
  2159  	assert.Equal(t, want, snap, "Unexpected snapshot of metrics.")
  2160  }
  2161  
  2162  func TestMiddlewareSuccessSnapshotForCall(t *testing.T) {
  2163  	timeVal := time.Now()
  2164  	defer stubTimeWithTimeVal(timeVal)()
  2165  	ttlMs := int64(1000)
  2166  	ttl := time.Millisecond * time.Duration(1000)
  2167  	root := metrics.New()
  2168  	meter := root.Scope()
  2169  	mw := NewMiddleware(Config{
  2170  		Logger:           zap.NewNop(),
  2171  		Scope:            meter,
  2172  		ContextExtractor: NewNopContextExtractor(),
  2173  	})
  2174  
  2175  	buf := bufferpool.Get()
  2176  	defer bufferpool.Put(buf)
  2177  
  2178  	buf.Write([]byte("body"))
  2179  
  2180  	ctx, cancel := context.WithDeadline(context.Background(), timeVal.Add(ttl))
  2181  	defer cancel()
  2182  	_, err := mw.Call(
  2183  		ctx,
  2184  		&transport.Request{
  2185  			Caller:          "caller",
  2186  			Service:         "service",
  2187  			Transport:       "",
  2188  			Encoding:        "raw",
  2189  			Procedure:       "procedure",
  2190  			ShardKey:        "sk",
  2191  			RoutingKey:      "rk",
  2192  			RoutingDelegate: "rd",
  2193  			Body:            buf,
  2194  			BodySize:        buf.Len(),
  2195  		},
  2196  		&fakeOutbound{body: []byte("body")},
  2197  	)
  2198  	assert.NoError(t, err, "Unexpected transport error.")
  2199  
  2200  	snap := root.Snapshot()
  2201  	tags := metrics.Tags{
  2202  		"dest":             "service",
  2203  		"direction":        "outbound",
  2204  		"transport":        "unknown",
  2205  		"encoding":         "raw",
  2206  		"procedure":        "procedure",
  2207  		"routing_delegate": "rd",
  2208  		"routing_key":      "rk",
  2209  		"rpc_type":         transport.Unary.String(),
  2210  		"source":           "caller",
  2211  	}
  2212  	want := &metrics.RootSnapshot{
  2213  		Counters: []metrics.Snapshot{
  2214  			{Name: "calls", Tags: tags, Value: 1},
  2215  			{Name: "panics", Tags: tags, Value: 0},
  2216  			{Name: "successes", Tags: tags, Value: 1},
  2217  		},
  2218  		Histograms: []metrics.HistogramSnapshot{
  2219  			{
  2220  				Name: "caller_failure_latency_ms",
  2221  				Tags: tags,
  2222  				Unit: time.Millisecond,
  2223  			},
  2224  			{
  2225  				Name:   "request_payload_size_bytes",
  2226  				Tags:   tags,
  2227  				Unit:   time.Millisecond,
  2228  				Values: []int64{4},
  2229  			},
  2230  			{
  2231  				Name:   "response_payload_size_bytes",
  2232  				Tags:   tags,
  2233  				Unit:   time.Millisecond,
  2234  				Values: []int64{4},
  2235  			},
  2236  			{
  2237  				Name: "server_failure_latency_ms",
  2238  				Tags: tags,
  2239  				Unit: time.Millisecond,
  2240  			},
  2241  			{
  2242  				Name:   "success_latency_ms",
  2243  				Tags:   tags,
  2244  				Unit:   time.Millisecond,
  2245  				Values: []int64{1},
  2246  			},
  2247  			{
  2248  				Name: "timeout_ttl_ms",
  2249  				Tags: tags,
  2250  				Unit: time.Millisecond,
  2251  			},
  2252  			{
  2253  				Name:   "ttl_ms",
  2254  				Tags:   tags,
  2255  				Unit:   time.Millisecond,
  2256  				Values: []int64{ttlMs},
  2257  			},
  2258  		},
  2259  	}
  2260  	assert.Equal(t, want, snap, "Unexpected snapshot of metrics.")
  2261  }
  2262  
  2263  func TestMiddlewareSuccessSnapshotForCallOnWay(t *testing.T) {
  2264  	timeVal := time.Now()
  2265  	defer stubTimeWithTimeVal(timeVal)()
  2266  	ttlMs := int64(1000)
  2267  	ttl := time.Millisecond * time.Duration(1000)
  2268  	root := metrics.New()
  2269  	meter := root.Scope()
  2270  	mw := NewMiddleware(Config{
  2271  		Logger:           zap.NewNop(),
  2272  		Scope:            meter,
  2273  		ContextExtractor: NewNopContextExtractor(),
  2274  	})
  2275  
  2276  	buf := bufferpool.Get()
  2277  	defer bufferpool.Put(buf)
  2278  
  2279  	buf.Write([]byte("body"))
  2280  
  2281  	ctx, cancel := context.WithDeadline(context.Background(), timeVal.Add(ttl))
  2282  	defer cancel()
  2283  	_, err := mw.CallOneway(
  2284  		ctx,
  2285  		&transport.Request{
  2286  			Caller:          "caller",
  2287  			Service:         "service",
  2288  			Transport:       "",
  2289  			Encoding:        "raw",
  2290  			Procedure:       "procedure",
  2291  			ShardKey:        "sk",
  2292  			RoutingKey:      "rk",
  2293  			RoutingDelegate: "rd",
  2294  			Body:            buf,
  2295  			BodySize:        buf.Len(),
  2296  		},
  2297  		&fakeOutbound{body: []byte("body")},
  2298  	)
  2299  	assert.NoError(t, err, "Unexpected transport error.")
  2300  
  2301  	snap := root.Snapshot()
  2302  	tags := metrics.Tags{
  2303  		"dest":             "service",
  2304  		"direction":        "outbound",
  2305  		"transport":        "unknown",
  2306  		"encoding":         "raw",
  2307  		"procedure":        "procedure",
  2308  		"routing_delegate": "rd",
  2309  		"routing_key":      "rk",
  2310  		"rpc_type":         transport.Oneway.String(),
  2311  		"source":           "caller",
  2312  	}
  2313  	want := &metrics.RootSnapshot{
  2314  		Counters: []metrics.Snapshot{
  2315  			{Name: "calls", Tags: tags, Value: 1},
  2316  			{Name: "panics", Tags: tags, Value: 0},
  2317  			{Name: "successes", Tags: tags, Value: 1},
  2318  		},
  2319  		Histograms: []metrics.HistogramSnapshot{
  2320  			{
  2321  				Name: "caller_failure_latency_ms",
  2322  				Tags: tags,
  2323  				Unit: time.Millisecond,
  2324  			},
  2325  			{
  2326  				Name:   "request_payload_size_bytes",
  2327  				Tags:   tags,
  2328  				Unit:   time.Millisecond,
  2329  				Values: []int64{4},
  2330  			},
  2331  			{
  2332  				Name: "response_payload_size_bytes",
  2333  				Tags: tags,
  2334  				Unit: time.Millisecond,
  2335  			},
  2336  			{
  2337  				Name: "server_failure_latency_ms",
  2338  				Tags: tags,
  2339  				Unit: time.Millisecond,
  2340  			},
  2341  			{
  2342  				Name:   "success_latency_ms",
  2343  				Tags:   tags,
  2344  				Unit:   time.Millisecond,
  2345  				Values: []int64{1},
  2346  			},
  2347  			{
  2348  				Name: "timeout_ttl_ms",
  2349  				Tags: tags,
  2350  				Unit: time.Millisecond,
  2351  			},
  2352  			{
  2353  				Name:   "ttl_ms",
  2354  				Tags:   tags,
  2355  				Unit:   time.Millisecond,
  2356  				Values: []int64{ttlMs},
  2357  			},
  2358  		},
  2359  	}
  2360  	assert.Equal(t, want, snap, "Unexpected snapshot of metrics.")
  2361  }
  2362  
  2363  func TestMiddlewareSuccessSnapshotForOneWay(t *testing.T) {
  2364  	timeVal := time.Now()
  2365  	defer stubTimeWithTimeVal(timeVal)()
  2366  	ttlMs := int64(1000)
  2367  	ttl := time.Millisecond * time.Duration(1000)
  2368  	root := metrics.New()
  2369  	meter := root.Scope()
  2370  	mw := NewMiddleware(Config{
  2371  		Logger:           zap.NewNop(),
  2372  		Scope:            meter,
  2373  		ContextExtractor: NewNopContextExtractor(),
  2374  	})
  2375  
  2376  	buf := bufferpool.Get()
  2377  	defer bufferpool.Put(buf)
  2378  
  2379  	buf.Write([]byte("body"))
  2380  
  2381  	ctx, cancel := context.WithDeadline(context.Background(), timeVal.Add(ttl))
  2382  	defer cancel()
  2383  	err := mw.HandleOneway(
  2384  		ctx,
  2385  		&transport.Request{
  2386  			Caller:          "caller",
  2387  			Service:         "service",
  2388  			Transport:       "",
  2389  			Encoding:        "raw",
  2390  			Procedure:       "procedure",
  2391  			ShardKey:        "sk",
  2392  			RoutingKey:      "rk",
  2393  			RoutingDelegate: "rd",
  2394  			Body:            buf,
  2395  			BodySize:        buf.Len(),
  2396  		},
  2397  		fakeHandler{responseData: []byte("test response")},
  2398  	)
  2399  	assert.NoError(t, err, "Unexpected transport error.")
  2400  
  2401  	snap := root.Snapshot()
  2402  	tags := metrics.Tags{
  2403  		"dest":             "service",
  2404  		"direction":        "inbound",
  2405  		"transport":        "unknown",
  2406  		"encoding":         "raw",
  2407  		"procedure":        "procedure",
  2408  		"routing_delegate": "rd",
  2409  		"routing_key":      "rk",
  2410  		"rpc_type":         transport.Oneway.String(),
  2411  		"source":           "caller",
  2412  	}
  2413  	want := &metrics.RootSnapshot{
  2414  		Counters: []metrics.Snapshot{
  2415  			{Name: "calls", Tags: tags, Value: 1},
  2416  			{Name: "panics", Tags: tags, Value: 0},
  2417  			{Name: "successes", Tags: tags, Value: 1},
  2418  		},
  2419  		Histograms: []metrics.HistogramSnapshot{
  2420  			{
  2421  				Name: "caller_failure_latency_ms",
  2422  				Tags: tags,
  2423  				Unit: time.Millisecond,
  2424  			},
  2425  			{
  2426  				Name:   "request_payload_size_bytes",
  2427  				Tags:   tags,
  2428  				Unit:   time.Millisecond,
  2429  				Values: []int64{4},
  2430  			},
  2431  			{
  2432  				Name: "response_payload_size_bytes",
  2433  				Tags: tags,
  2434  				Unit: time.Millisecond,
  2435  			},
  2436  			{
  2437  				Name: "server_failure_latency_ms",
  2438  				Tags: tags,
  2439  				Unit: time.Millisecond,
  2440  			},
  2441  			{
  2442  				Name:   "success_latency_ms",
  2443  				Tags:   tags,
  2444  				Unit:   time.Millisecond,
  2445  				Values: []int64{1},
  2446  			},
  2447  			{
  2448  				Name: "timeout_ttl_ms",
  2449  				Tags: tags,
  2450  				Unit: time.Millisecond,
  2451  			},
  2452  			{
  2453  				Name:   "ttl_ms",
  2454  				Tags:   tags,
  2455  				Unit:   time.Millisecond,
  2456  				Values: []int64{ttlMs},
  2457  			},
  2458  		},
  2459  	}
  2460  	assert.Equal(t, want, snap, "Unexpected snapshot of metrics.")
  2461  }
  2462  
  2463  func TestMiddlewareFailureSnapshot(t *testing.T) {
  2464  	defer stubTime()()
  2465  	root := metrics.New()
  2466  	meter := root.Scope()
  2467  	mw := NewMiddleware(Config{
  2468  		Logger:           zap.NewNop(),
  2469  		Scope:            meter,
  2470  		ContextExtractor: NewNopContextExtractor(),
  2471  	})
  2472  
  2473  	buf := bufferpool.Get()
  2474  	defer bufferpool.Put(buf)
  2475  
  2476  	buf.Write([]byte("test body"))
  2477  
  2478  	err := mw.Handle(
  2479  		context.Background(),
  2480  		&transport.Request{
  2481  			Caller:          "caller",
  2482  			Service:         "service",
  2483  			Transport:       "",
  2484  			Encoding:        "raw",
  2485  			Procedure:       "procedure",
  2486  			ShardKey:        "sk",
  2487  			RoutingKey:      "rk",
  2488  			RoutingDelegate: "rd",
  2489  			Body:            buf,
  2490  			BodySize:        buf.Len(),
  2491  		},
  2492  		&transporttest.FakeResponseWriter{},
  2493  		fakeHandler{err: fmt.Errorf("yuno"), applicationErr: false, responseData: []byte("error")},
  2494  	)
  2495  	assert.Error(t, err, "Expected transport error.")
  2496  
  2497  	snap := root.Snapshot()
  2498  	tags := metrics.Tags{
  2499  		"dest":             "service",
  2500  		"direction":        "inbound",
  2501  		"encoding":         "raw",
  2502  		"procedure":        "procedure",
  2503  		"routing_delegate": "rd",
  2504  		"routing_key":      "rk",
  2505  		"rpc_type":         transport.Unary.String(),
  2506  		"source":           "caller",
  2507  		"transport":        "unknown",
  2508  	}
  2509  	errorTags := metrics.Tags{
  2510  		"dest":             "service",
  2511  		"direction":        "inbound",
  2512  		"encoding":         "raw",
  2513  		"error":            "unknown_internal_yarpc",
  2514  		"error_name":       _notSet,
  2515  		"procedure":        "procedure",
  2516  		"routing_delegate": "rd",
  2517  		"routing_key":      "rk",
  2518  		"rpc_type":         transport.Unary.String(),
  2519  		"source":           "caller",
  2520  		"transport":        "unknown",
  2521  	}
  2522  	want := &metrics.RootSnapshot{
  2523  		Counters: []metrics.Snapshot{
  2524  			{Name: "calls", Tags: tags, Value: 1},
  2525  			{Name: "panics", Tags: tags, Value: 0},
  2526  			{Name: "server_failures", Tags: errorTags, Value: 1},
  2527  			{Name: "successes", Tags: tags, Value: 0},
  2528  		},
  2529  		Histograms: []metrics.HistogramSnapshot{
  2530  			{
  2531  				Name: "caller_failure_latency_ms",
  2532  				Tags: tags,
  2533  				Unit: time.Millisecond,
  2534  			},
  2535  			{
  2536  				Name:   "request_payload_size_bytes",
  2537  				Tags:   tags,
  2538  				Unit:   time.Millisecond,
  2539  				Values: []int64{16},
  2540  			},
  2541  			{
  2542  				Name: "response_payload_size_bytes",
  2543  				Tags: tags,
  2544  				Unit: time.Millisecond,
  2545  			},
  2546  			{
  2547  				Name:   "server_failure_latency_ms",
  2548  				Tags:   tags,
  2549  				Unit:   time.Millisecond,
  2550  				Values: []int64{1},
  2551  			},
  2552  			{
  2553  				Name: "success_latency_ms",
  2554  				Tags: tags,
  2555  				Unit: time.Millisecond,
  2556  			},
  2557  			{
  2558  				Name: "timeout_ttl_ms",
  2559  				Tags: tags,
  2560  				Unit: time.Millisecond,
  2561  			},
  2562  			{
  2563  				Name: "ttl_ms",
  2564  				Tags: tags,
  2565  				Unit: time.Millisecond,
  2566  			},
  2567  		},
  2568  	}
  2569  	assert.Equal(t, want, snap, "Unexpected snapshot of metrics.")
  2570  }
  2571  
  2572  func TestMiddlewareFailureWithDeadlineExceededSnapshot(t *testing.T) {
  2573  	timeVal := time.Now()
  2574  	defer stubTimeWithTimeVal(timeVal)()
  2575  
  2576  	ttlMs := int64(1000)
  2577  	ttl := time.Millisecond * time.Duration(ttlMs)
  2578  	root := metrics.New()
  2579  	meter := root.Scope()
  2580  	mw := NewMiddleware(Config{
  2581  		Logger:           zap.NewNop(),
  2582  		Scope:            meter,
  2583  		ContextExtractor: NewNopContextExtractor(),
  2584  	})
  2585  
  2586  	buf := bufferpool.Get()
  2587  	defer bufferpool.Put(buf)
  2588  
  2589  	buf.Write([]byte("test body"))
  2590  
  2591  	ctx, cancel := context.WithDeadline(context.Background(), timeVal.Add(ttl))
  2592  	defer cancel()
  2593  	err := mw.Handle(
  2594  		ctx,
  2595  		&transport.Request{
  2596  			Caller:          "caller",
  2597  			Service:         "service",
  2598  			Transport:       "",
  2599  			Encoding:        "raw",
  2600  			Procedure:       "procedure",
  2601  			ShardKey:        "sk",
  2602  			RoutingKey:      "rk",
  2603  			RoutingDelegate: "rd",
  2604  			Body:            buf,
  2605  			BodySize:        buf.Len(),
  2606  		},
  2607  		&transporttest.FakeResponseWriter{},
  2608  		fakeHandler{
  2609  			err:            yarpcerrors.DeadlineExceededErrorf("test deadline"),
  2610  			applicationErr: false,
  2611  			responseData:   []byte("deadline response"),
  2612  		},
  2613  	)
  2614  	assert.Error(t, err, "Expected transport error.")
  2615  
  2616  	snap := root.Snapshot()
  2617  	tags := metrics.Tags{
  2618  		"dest":             "service",
  2619  		"direction":        "inbound",
  2620  		"encoding":         "raw",
  2621  		"procedure":        "procedure",
  2622  		"routing_delegate": "rd",
  2623  		"routing_key":      "rk",
  2624  		"rpc_type":         transport.Unary.String(),
  2625  		"source":           "caller",
  2626  		"transport":        "unknown",
  2627  	}
  2628  	errorTags := metrics.Tags{
  2629  		"dest":             "service",
  2630  		"direction":        "inbound",
  2631  		"encoding":         "raw",
  2632  		"error":            "deadline-exceeded",
  2633  		"error_name":       _notSet,
  2634  		"procedure":        "procedure",
  2635  		"routing_delegate": "rd",
  2636  		"routing_key":      "rk",
  2637  		"rpc_type":         transport.Unary.String(),
  2638  		"source":           "caller",
  2639  		"transport":        "unknown",
  2640  	}
  2641  	want := &metrics.RootSnapshot{
  2642  		Counters: []metrics.Snapshot{
  2643  			{Name: "calls", Tags: tags, Value: 1},
  2644  			{Name: "panics", Tags: tags, Value: 0},
  2645  			{Name: "server_failures", Tags: errorTags, Value: 1},
  2646  			{Name: "successes", Tags: tags, Value: 0},
  2647  		},
  2648  		Histograms: []metrics.HistogramSnapshot{
  2649  			{
  2650  				Name: "caller_failure_latency_ms",
  2651  				Tags: tags,
  2652  				Unit: time.Millisecond,
  2653  			},
  2654  			{
  2655  				Name:   "request_payload_size_bytes",
  2656  				Tags:   tags,
  2657  				Unit:   time.Millisecond,
  2658  				Values: []int64{16},
  2659  			},
  2660  			{
  2661  				Name: "response_payload_size_bytes",
  2662  				Tags: tags,
  2663  				Unit: time.Millisecond,
  2664  			},
  2665  			{
  2666  				Name:   "server_failure_latency_ms",
  2667  				Tags:   tags,
  2668  				Unit:   time.Millisecond,
  2669  				Values: []int64{1},
  2670  			},
  2671  			{
  2672  				Name: "success_latency_ms",
  2673  				Tags: tags,
  2674  				Unit: time.Millisecond,
  2675  			},
  2676  			{
  2677  				Name:   "timeout_ttl_ms",
  2678  				Tags:   tags,
  2679  				Unit:   time.Millisecond,
  2680  				Values: []int64{ttlMs},
  2681  			},
  2682  			{
  2683  				Name:   "ttl_ms",
  2684  				Tags:   tags,
  2685  				Unit:   time.Millisecond,
  2686  				Values: []int64{ttlMs},
  2687  			},
  2688  		},
  2689  	}
  2690  	assert.Equal(t, want, snap, "Unexpected snapshot of metrics.")
  2691  }
  2692  
  2693  func TestApplicationErrorSnapShot(t *testing.T) {
  2694  	tests := []struct {
  2695  		name       string
  2696  		err        error
  2697  		errTag     string
  2698  		errNameTag string
  2699  		appErr     bool
  2700  		appErrName string
  2701  	}{
  2702  		{
  2703  			name:       "status", // eg error returned in transport middleware
  2704  			err:        yarpcerrors.Newf(yarpcerrors.CodeAlreadyExists, "foo exists!"),
  2705  			errTag:     "already-exists",
  2706  			errNameTag: _notSet,
  2707  		},
  2708  		{
  2709  			name:       "status and app error", // eg Protobuf handler returning yarpcerrors.Status
  2710  			err:        yarpcerrors.Newf(yarpcerrors.CodeAlreadyExists, "foo exists!"),
  2711  			errTag:     "already-exists",
  2712  			errNameTag: _notSet,
  2713  			appErr:     true,
  2714  		},
  2715  		{
  2716  			name:       "no status and app error", // eg Thrift exception
  2717  			err:        errors.New("foo-bar-baz"),
  2718  			errTag:     "application_error",
  2719  			errNameTag: "FakeError1",
  2720  			appErr:     true,
  2721  			appErrName: "FakeError1",
  2722  		},
  2723  	}
  2724  
  2725  	for _, tt := range tests {
  2726  		t.Run(tt.name, func(t *testing.T) {
  2727  			timeVal := time.Now()
  2728  			defer stubTimeWithTimeVal(timeVal)()
  2729  
  2730  			ttlMs := int64(1000)
  2731  			ttl := time.Millisecond * time.Duration(1000)
  2732  			root := metrics.New()
  2733  			meter := root.Scope()
  2734  			mw := NewMiddleware(Config{
  2735  				Logger: zap.NewNop(),
  2736  				Scope:  meter,
  2737  			})
  2738  			ctx, cancel := context.WithDeadline(context.Background(), timeVal.Add(ttl))
  2739  			defer cancel()
  2740  			err := mw.Handle(
  2741  				ctx,
  2742  				&transport.Request{
  2743  					Caller:          "caller",
  2744  					Service:         "service",
  2745  					Transport:       "",
  2746  					Encoding:        "raw",
  2747  					Procedure:       "procedure",
  2748  					ShardKey:        "sk",
  2749  					RoutingKey:      "rk",
  2750  					RoutingDelegate: "rd",
  2751  				},
  2752  				&transporttest.FakeResponseWriter{},
  2753  				fakeHandler{
  2754  					err:                tt.err,
  2755  					applicationErr:     tt.appErr,
  2756  					applicationErrName: tt.appErrName,
  2757  				},
  2758  			)
  2759  			require.Error(t, err)
  2760  
  2761  			snap := root.Snapshot()
  2762  			tags := metrics.Tags{
  2763  				"dest":             "service",
  2764  				"direction":        "inbound",
  2765  				"transport":        "unknown",
  2766  				"encoding":         "raw",
  2767  				"procedure":        "procedure",
  2768  				"routing_delegate": "rd",
  2769  				"routing_key":      "rk",
  2770  				"rpc_type":         transport.Unary.String(),
  2771  				"source":           "caller",
  2772  			}
  2773  			errorTags := metrics.Tags{
  2774  				"dest":             "service",
  2775  				"direction":        "inbound",
  2776  				"transport":        "unknown",
  2777  				"encoding":         "raw",
  2778  				"procedure":        "procedure",
  2779  				"routing_delegate": "rd",
  2780  				"routing_key":      "rk",
  2781  				"rpc_type":         transport.Unary.String(),
  2782  				"source":           "caller",
  2783  				"error":            tt.errTag,
  2784  				"error_name":       tt.errNameTag,
  2785  			}
  2786  			want := &metrics.RootSnapshot{
  2787  				Counters: []metrics.Snapshot{
  2788  					{Name: "caller_failures", Tags: errorTags, Value: 1},
  2789  					{Name: "calls", Tags: tags, Value: 1},
  2790  					{Name: "panics", Tags: tags, Value: 0},
  2791  					{Name: "successes", Tags: tags, Value: 0},
  2792  				},
  2793  				Histograms: []metrics.HistogramSnapshot{
  2794  					{
  2795  						Name:   "caller_failure_latency_ms",
  2796  						Tags:   tags,
  2797  						Unit:   time.Millisecond,
  2798  						Values: []int64{1},
  2799  					},
  2800  					{
  2801  						Name:   "request_payload_size_bytes",
  2802  						Tags:   tags,
  2803  						Unit:   time.Millisecond,
  2804  						Values: []int64{0},
  2805  					},
  2806  					{
  2807  						Name: "response_payload_size_bytes",
  2808  						Tags: tags,
  2809  						Unit: time.Millisecond,
  2810  					},
  2811  					{
  2812  						Name: "server_failure_latency_ms",
  2813  						Tags: tags,
  2814  						Unit: time.Millisecond,
  2815  					},
  2816  					{
  2817  						Name: "success_latency_ms",
  2818  						Tags: tags,
  2819  						Unit: time.Millisecond,
  2820  					},
  2821  					{
  2822  						Name: "timeout_ttl_ms",
  2823  						Tags: tags,
  2824  						Unit: time.Millisecond,
  2825  					},
  2826  					{
  2827  						Name:   "ttl_ms",
  2828  						Tags:   tags,
  2829  						Unit:   time.Millisecond,
  2830  						Values: []int64{ttlMs},
  2831  					},
  2832  				},
  2833  			}
  2834  			assert.Equal(t, want, snap, "Unexpected snapshot of metrics.")
  2835  		})
  2836  	}
  2837  }
  2838  
  2839  func TestUnaryInboundApplicationPanics(t *testing.T) {
  2840  	var err error
  2841  	root := metrics.New()
  2842  	scope := root.Scope()
  2843  	mw := NewMiddleware(Config{
  2844  		Logger:           zap.NewNop(),
  2845  		Scope:            scope,
  2846  		ContextExtractor: NewNopContextExtractor(),
  2847  	})
  2848  	newTags := func(direction directionName, withErr string) metrics.Tags {
  2849  		tags := metrics.Tags{
  2850  			"dest":             "service",
  2851  			"direction":        string(direction),
  2852  			"encoding":         "raw",
  2853  			"procedure":        "procedure",
  2854  			"routing_delegate": "rd",
  2855  			"routing_key":      "rk",
  2856  			"rpc_type":         transport.Unary.String(),
  2857  			"source":           "caller",
  2858  			"transport":        "unknown",
  2859  		}
  2860  		if withErr != "" {
  2861  			tags["error"] = withErr
  2862  			tags["error_name"] = "__not_set__"
  2863  		}
  2864  		return tags
  2865  	}
  2866  	tags := newTags(_directionInbound, "")
  2867  	errTags := newTags(_directionInbound, "internal")
  2868  
  2869  	t.Run("Test panic in Handle", func(t *testing.T) {
  2870  		// As our fake handler is mocked to panic in the call, test that the invocation panics
  2871  		assert.Panics(t, func() {
  2872  			err = mw.Handle(
  2873  				context.Background(),
  2874  				&transport.Request{
  2875  					Caller:          "caller",
  2876  					Service:         "service",
  2877  					Transport:       "",
  2878  					Encoding:        "raw",
  2879  					Procedure:       "procedure",
  2880  					ShardKey:        "sk",
  2881  					RoutingKey:      "rk",
  2882  					RoutingDelegate: "rd",
  2883  				},
  2884  				&transporttest.FakeResponseWriter{},
  2885  				fakeHandler{applicationPanic: true},
  2886  			)
  2887  		})
  2888  		require.NoError(t, err)
  2889  
  2890  		want := &metrics.RootSnapshot{
  2891  			Counters: []metrics.Snapshot{
  2892  				{Name: "calls", Tags: tags, Value: 1},
  2893  				{Name: "panics", Tags: tags, Value: 1},
  2894  				{Name: "server_failures", Tags: errTags, Value: 1},
  2895  				{Name: "successes", Tags: tags, Value: 0},
  2896  			},
  2897  			Histograms: []metrics.HistogramSnapshot{
  2898  				{
  2899  					Name: "caller_failure_latency_ms",
  2900  					Tags: tags,
  2901  					Unit: time.Millisecond,
  2902  				},
  2903  				{
  2904  					Name:   "request_payload_size_bytes",
  2905  					Tags:   tags,
  2906  					Unit:   time.Millisecond,
  2907  					Values: []int64{0},
  2908  				},
  2909  				{
  2910  					Name: "response_payload_size_bytes",
  2911  					Tags: tags,
  2912  					Unit: time.Millisecond,
  2913  				},
  2914  				{
  2915  					Name:   "server_failure_latency_ms",
  2916  					Tags:   tags,
  2917  					Unit:   time.Millisecond,
  2918  					Values: []int64{1},
  2919  				},
  2920  				{
  2921  					Name: "success_latency_ms",
  2922  					Tags: tags,
  2923  					Unit: time.Millisecond,
  2924  				},
  2925  				{
  2926  					Name: "timeout_ttl_ms",
  2927  					Tags: tags,
  2928  					Unit: time.Millisecond,
  2929  				},
  2930  				{
  2931  					Name: "ttl_ms",
  2932  					Tags: tags,
  2933  					Unit: time.Millisecond,
  2934  				},
  2935  			},
  2936  		}
  2937  		assert.Equal(t, want, root.Snapshot(), "unexpected metrics snapshot")
  2938  	})
  2939  }
  2940  
  2941  func TestUnaryOutboundApplicationPanics(t *testing.T) {
  2942  	var err error
  2943  	root := metrics.New()
  2944  	scope := root.Scope()
  2945  	mw := NewMiddleware(Config{
  2946  		Logger:           zap.NewNop(),
  2947  		Scope:            scope,
  2948  		ContextExtractor: NewNopContextExtractor(),
  2949  	})
  2950  	newTags := func(direction directionName, withErr string) metrics.Tags {
  2951  		tags := metrics.Tags{
  2952  			"dest":             "service",
  2953  			"direction":        string(direction),
  2954  			"encoding":         "raw",
  2955  			"procedure":        "procedure",
  2956  			"routing_delegate": "rd",
  2957  			"routing_key":      "rk",
  2958  			"rpc_type":         transport.Unary.String(),
  2959  			"source":           "caller",
  2960  			"transport":        "unknown",
  2961  		}
  2962  		if withErr != "" {
  2963  			tags["error"] = withErr
  2964  			tags["error_name"] = "__not_set__"
  2965  		}
  2966  		return tags
  2967  	}
  2968  	tags := newTags(_directionOutbound, "")
  2969  	errTags := newTags(_directionOutbound, "internal")
  2970  
  2971  	t.Run("Test panic in Call", func(t *testing.T) {
  2972  		// As our fake handler is mocked to panic in the call, test that the invocation panics
  2973  		assert.Panics(t, func() {
  2974  			_, err = mw.Call(
  2975  				context.Background(),
  2976  				&transport.Request{
  2977  					Caller:          "caller",
  2978  					Service:         "service",
  2979  					Transport:       "",
  2980  					Encoding:        "raw",
  2981  					Procedure:       "procedure",
  2982  					ShardKey:        "sk",
  2983  					RoutingKey:      "rk",
  2984  					RoutingDelegate: "rd",
  2985  				},
  2986  				fakeOutbound{applicationPanic: true},
  2987  			)
  2988  		})
  2989  		require.NoError(t, err)
  2990  
  2991  		want := &metrics.RootSnapshot{
  2992  			Counters: []metrics.Snapshot{
  2993  				{Name: "calls", Tags: tags, Value: 1},
  2994  				{Name: "panics", Tags: tags, Value: 1},
  2995  				{Name: "server_failures", Tags: errTags, Value: 1},
  2996  				{Name: "successes", Tags: tags, Value: 0},
  2997  			},
  2998  			Histograms: []metrics.HistogramSnapshot{
  2999  				{
  3000  					Name: "caller_failure_latency_ms",
  3001  					Tags: tags,
  3002  					Unit: time.Millisecond,
  3003  				},
  3004  				{
  3005  					Name:   "request_payload_size_bytes",
  3006  					Tags:   tags,
  3007  					Unit:   time.Millisecond,
  3008  					Values: []int64{0},
  3009  				},
  3010  				{
  3011  					Name: "response_payload_size_bytes",
  3012  					Tags: tags,
  3013  					Unit: time.Millisecond,
  3014  				},
  3015  				{
  3016  					Name:   "server_failure_latency_ms",
  3017  					Tags:   tags,
  3018  					Unit:   time.Millisecond,
  3019  					Values: []int64{1},
  3020  				},
  3021  				{
  3022  					Name: "success_latency_ms",
  3023  					Tags: tags,
  3024  					Unit: time.Millisecond,
  3025  				},
  3026  				{
  3027  					Name: "timeout_ttl_ms",
  3028  					Tags: tags,
  3029  					Unit: time.Millisecond,
  3030  				},
  3031  				{
  3032  					Name: "ttl_ms",
  3033  					Tags: tags,
  3034  					Unit: time.Millisecond,
  3035  				},
  3036  			},
  3037  		}
  3038  		assert.Equal(t, want, root.Snapshot(), "unexpected metrics snapshot")
  3039  	})
  3040  }
  3041  func TestOnewayInboundApplicationPanics(t *testing.T) {
  3042  	var err error
  3043  	root := metrics.New()
  3044  	scope := root.Scope()
  3045  	mw := NewMiddleware(Config{
  3046  		Logger:           zap.NewNop(),
  3047  		Scope:            scope,
  3048  		ContextExtractor: NewNopContextExtractor(),
  3049  	})
  3050  	newTags := func(direction directionName, withErr string) metrics.Tags {
  3051  		tags := metrics.Tags{
  3052  			"dest":             "service",
  3053  			"direction":        string(direction),
  3054  			"encoding":         "raw",
  3055  			"procedure":        "procedure",
  3056  			"routing_delegate": "rd",
  3057  			"routing_key":      "rk",
  3058  			"rpc_type":         transport.Oneway.String(),
  3059  			"source":           "caller",
  3060  			"transport":        "unknown",
  3061  		}
  3062  		if withErr != "" {
  3063  			tags["error"] = withErr
  3064  			tags["error_name"] = "__not_set__"
  3065  		}
  3066  		return tags
  3067  	}
  3068  	tags := newTags(_directionInbound, "")
  3069  	errTags := newTags(_directionInbound, "internal")
  3070  
  3071  	t.Run("Test panic in HandleOneway", func(t *testing.T) {
  3072  		// As our fake handler is mocked to panic in the call, test that the invocation panics
  3073  		assert.Panics(t, func() {
  3074  			err = mw.HandleOneway(
  3075  				context.Background(),
  3076  				&transport.Request{
  3077  					Caller:          "caller",
  3078  					Service:         "service",
  3079  					Transport:       "",
  3080  					Encoding:        "raw",
  3081  					Procedure:       "procedure",
  3082  					ShardKey:        "sk",
  3083  					RoutingKey:      "rk",
  3084  					RoutingDelegate: "rd",
  3085  				},
  3086  				fakeHandler{applicationPanic: true},
  3087  			)
  3088  		})
  3089  		require.NoError(t, err)
  3090  
  3091  		want := &metrics.RootSnapshot{
  3092  			Counters: []metrics.Snapshot{
  3093  				{Name: "calls", Tags: tags, Value: 1},
  3094  				{Name: "panics", Tags: tags, Value: 1},
  3095  				{Name: "server_failures", Tags: errTags, Value: 1},
  3096  				{Name: "successes", Tags: tags, Value: 0},
  3097  			},
  3098  			Histograms: []metrics.HistogramSnapshot{
  3099  				{
  3100  					Name: "caller_failure_latency_ms",
  3101  					Tags: tags,
  3102  					Unit: time.Millisecond,
  3103  				},
  3104  				{
  3105  					Name:   "request_payload_size_bytes",
  3106  					Tags:   tags,
  3107  					Unit:   time.Millisecond,
  3108  					Values: []int64{0},
  3109  				},
  3110  				{
  3111  					Name: "response_payload_size_bytes",
  3112  					Tags: tags,
  3113  					Unit: time.Millisecond,
  3114  				},
  3115  				{
  3116  					Name:   "server_failure_latency_ms",
  3117  					Tags:   tags,
  3118  					Unit:   time.Millisecond,
  3119  					Values: []int64{1},
  3120  				},
  3121  				{
  3122  					Name: "success_latency_ms",
  3123  					Tags: tags,
  3124  					Unit: time.Millisecond,
  3125  				},
  3126  				{
  3127  					Name: "timeout_ttl_ms",
  3128  					Tags: tags,
  3129  					Unit: time.Millisecond,
  3130  				},
  3131  				{
  3132  					Name: "ttl_ms",
  3133  					Tags: tags,
  3134  					Unit: time.Millisecond,
  3135  				},
  3136  			},
  3137  		}
  3138  		assert.Equal(t, want, root.Snapshot(), "unexpected metrics snapshot")
  3139  	})
  3140  }
  3141  
  3142  func TestOnewayOutboundApplicationPanics(t *testing.T) {
  3143  	var err error
  3144  	root := metrics.New()
  3145  	scope := root.Scope()
  3146  	mw := NewMiddleware(Config{
  3147  		Logger:           zap.NewNop(),
  3148  		Scope:            scope,
  3149  		ContextExtractor: NewNopContextExtractor(),
  3150  	})
  3151  	newTags := func(direction directionName, withErr string) metrics.Tags {
  3152  		tags := metrics.Tags{
  3153  			"dest":             "service",
  3154  			"direction":        string(direction),
  3155  			"encoding":         "raw",
  3156  			"procedure":        "procedure",
  3157  			"routing_delegate": "rd",
  3158  			"routing_key":      "rk",
  3159  			"rpc_type":         transport.Oneway.String(),
  3160  			"source":           "caller",
  3161  			"transport":        "unknown",
  3162  		}
  3163  		if withErr != "" {
  3164  			tags["error"] = withErr
  3165  			tags["error_name"] = "__not_set__"
  3166  		}
  3167  		return tags
  3168  	}
  3169  	tags := newTags(_directionOutbound, "")
  3170  	errTags := newTags(_directionOutbound, "internal")
  3171  
  3172  	t.Run("Test panic in CallOneway", func(t *testing.T) {
  3173  		// As our fake handler is mocked to panic in the call, test that the invocation panics
  3174  		assert.Panics(t, func() {
  3175  			_, err = mw.CallOneway(
  3176  				context.Background(),
  3177  				&transport.Request{
  3178  					Caller:          "caller",
  3179  					Service:         "service",
  3180  					Transport:       "",
  3181  					Encoding:        "raw",
  3182  					Procedure:       "procedure",
  3183  					ShardKey:        "sk",
  3184  					RoutingKey:      "rk",
  3185  					RoutingDelegate: "rd",
  3186  				},
  3187  				fakeOutbound{applicationPanic: true},
  3188  			)
  3189  		})
  3190  		require.NoError(t, err)
  3191  
  3192  		want := &metrics.RootSnapshot{
  3193  			Counters: []metrics.Snapshot{
  3194  				{Name: "calls", Tags: tags, Value: 1},
  3195  				{Name: "panics", Tags: tags, Value: 1},
  3196  				{Name: "server_failures", Tags: errTags, Value: 1},
  3197  				{Name: "successes", Tags: tags, Value: 0},
  3198  			},
  3199  			Histograms: []metrics.HistogramSnapshot{
  3200  				{
  3201  					Name: "caller_failure_latency_ms",
  3202  					Tags: tags,
  3203  					Unit: time.Millisecond,
  3204  				},
  3205  				{
  3206  					Name:   "request_payload_size_bytes",
  3207  					Tags:   tags,
  3208  					Unit:   time.Millisecond,
  3209  					Values: []int64{0},
  3210  				},
  3211  				{
  3212  					Name: "response_payload_size_bytes",
  3213  					Tags: tags,
  3214  					Unit: time.Millisecond,
  3215  				},
  3216  				{
  3217  					Name:   "server_failure_latency_ms",
  3218  					Tags:   tags,
  3219  					Unit:   time.Millisecond,
  3220  					Values: []int64{1},
  3221  				},
  3222  				{
  3223  					Name: "success_latency_ms",
  3224  					Tags: tags,
  3225  					Unit: time.Millisecond,
  3226  				},
  3227  				{
  3228  					Name: "timeout_ttl_ms",
  3229  					Tags: tags,
  3230  					Unit: time.Millisecond,
  3231  				},
  3232  				{
  3233  					Name: "ttl_ms",
  3234  					Tags: tags,
  3235  					Unit: time.Millisecond,
  3236  				},
  3237  			},
  3238  		}
  3239  		assert.Equal(t, want, root.Snapshot(), "unexpected metrics snapshot")
  3240  	})
  3241  }
  3242  
  3243  func TestStreamingInboundApplicationPanics(t *testing.T) {
  3244  	root := metrics.New()
  3245  	scope := root.Scope()
  3246  	mw := NewMiddleware(Config{
  3247  		Logger:           zap.NewNop(),
  3248  		Scope:            scope,
  3249  		ContextExtractor: NewNopContextExtractor(),
  3250  	})
  3251  	stream, err := transport.NewServerStream(&fakeStream{
  3252  		request: &transport.StreamRequest{
  3253  			Meta: &transport.RequestMeta{
  3254  				Caller:          "caller",
  3255  				Service:         "service",
  3256  				Transport:       "",
  3257  				Encoding:        "raw",
  3258  				Procedure:       "procedure",
  3259  				ShardKey:        "sk",
  3260  				RoutingKey:      "rk",
  3261  				RoutingDelegate: "rd",
  3262  			},
  3263  		},
  3264  	})
  3265  	require.NoError(t, err)
  3266  	newTags := func(direction directionName, withErr string) metrics.Tags {
  3267  		tags := metrics.Tags{
  3268  			"dest":             "service",
  3269  			"direction":        string(direction),
  3270  			"encoding":         "raw",
  3271  			"procedure":        "procedure",
  3272  			"routing_delegate": "rd",
  3273  			"routing_key":      "rk",
  3274  			"rpc_type":         transport.Streaming.String(),
  3275  			"source":           "caller",
  3276  			"transport":        "unknown",
  3277  		}
  3278  		if withErr != "" {
  3279  			tags["error"] = withErr
  3280  			tags["error_name"] = "__not_set__"
  3281  		}
  3282  		return tags
  3283  	}
  3284  	tags := newTags(_directionInbound, "")
  3285  	errTags := newTags(_directionInbound, "internal")
  3286  
  3287  	t.Run("Test panic in HandleStream", func(t *testing.T) {
  3288  		// As our fake handler is mocked to panic in the call, test that the invocation panics
  3289  		assert.Panics(t, func() {
  3290  			err = mw.HandleStream(stream, &fakeHandler{applicationPanic: true})
  3291  		})
  3292  		require.NoError(t, err)
  3293  
  3294  		want := &metrics.RootSnapshot{
  3295  			Counters: []metrics.Snapshot{
  3296  				{Name: "calls", Tags: tags, Value: 1},
  3297  				{Name: "panics", Tags: tags, Value: 1},
  3298  				{Name: "server_failures", Tags: errTags, Value: 1},
  3299  				{Name: "stream_receive_successes", Tags: tags, Value: 0},
  3300  				{Name: "stream_receives", Tags: tags, Value: 0},
  3301  				{Name: "stream_send_successes", Tags: tags, Value: 0},
  3302  				{Name: "stream_sends", Tags: tags, Value: 0},
  3303  				{Name: "successes", Tags: tags, Value: 1},
  3304  			},
  3305  			Gauges: []metrics.Snapshot{
  3306  				{Name: "streams_active", Tags: tags, Value: 0},
  3307  			},
  3308  			Histograms: []metrics.HistogramSnapshot{
  3309  				{
  3310  					Name:   "stream_duration_ms",
  3311  					Tags:   tags,
  3312  					Unit:   time.Millisecond,
  3313  					Values: []int64{1},
  3314  				},
  3315  				{
  3316  					Name: "stream_request_payload_size_bytes",
  3317  					Tags: tags,
  3318  					Unit: time.Millisecond,
  3319  				},
  3320  				{
  3321  					Name: "stream_response_payload_size_bytes",
  3322  					Tags: tags,
  3323  					Unit: time.Millisecond,
  3324  				},
  3325  			},
  3326  		}
  3327  		assert.Equal(t, want, root.Snapshot(), "unexpected metrics snapshot")
  3328  	})
  3329  }
  3330  
  3331  func TestStreamingOutboundApplicationPanics(t *testing.T) {
  3332  	root := metrics.New()
  3333  	scope := root.Scope()
  3334  	mw := NewMiddleware(Config{
  3335  		Logger:           zap.NewNop(),
  3336  		Scope:            scope,
  3337  		ContextExtractor: NewNopContextExtractor(),
  3338  	})
  3339  	stream, err := transport.NewServerStream(&fakeStream{
  3340  		request: &transport.StreamRequest{
  3341  			Meta: &transport.RequestMeta{
  3342  				Caller:          "caller",
  3343  				Service:         "service",
  3344  				Transport:       "",
  3345  				Encoding:        "raw",
  3346  				Procedure:       "procedure",
  3347  				ShardKey:        "sk",
  3348  				RoutingKey:      "rk",
  3349  				RoutingDelegate: "rd",
  3350  			},
  3351  		},
  3352  	})
  3353  	require.NoError(t, err)
  3354  	newTags := func(direction directionName, withErr string) metrics.Tags {
  3355  		tags := metrics.Tags{
  3356  			"dest":             "service",
  3357  			"direction":        string(direction),
  3358  			"encoding":         "raw",
  3359  			"procedure":        "procedure",
  3360  			"routing_delegate": "rd",
  3361  			"routing_key":      "rk",
  3362  			"rpc_type":         transport.Streaming.String(),
  3363  			"source":           "caller",
  3364  			"transport":        "unknown",
  3365  		}
  3366  		if withErr != "" {
  3367  			tags["error"] = withErr
  3368  			tags["error_name"] = "__not_set__"
  3369  		}
  3370  		return tags
  3371  	}
  3372  	tags := newTags(_directionOutbound, "")
  3373  	errTags := newTags(_directionOutbound, "internal")
  3374  
  3375  	t.Run("Test panic in CallStream", func(t *testing.T) {
  3376  		// As our fake handler is mocked to panic in the call, test that the invocation panics
  3377  		assert.Panics(t, func() {
  3378  			_, err = mw.CallStream(
  3379  				context.Background(),
  3380  				stream.Request(),
  3381  				fakeOutbound{applicationPanic: true})
  3382  		})
  3383  		require.NoError(t, err)
  3384  
  3385  		want := &metrics.RootSnapshot{
  3386  			Counters: []metrics.Snapshot{
  3387  				{Name: "calls", Tags: tags, Value: 0},
  3388  				{Name: "panics", Tags: tags, Value: 1},
  3389  				{Name: "server_failures", Tags: errTags, Value: 1},
  3390  				{Name: "stream_receive_successes", Tags: tags, Value: 0},
  3391  				{Name: "stream_receives", Tags: tags, Value: 0},
  3392  				{Name: "stream_send_successes", Tags: tags, Value: 0},
  3393  				{Name: "stream_sends", Tags: tags, Value: 0},
  3394  				{Name: "successes", Tags: tags, Value: 0},
  3395  			},
  3396  			Gauges: []metrics.Snapshot{
  3397  				{Name: "streams_active", Tags: tags, Value: -1},
  3398  			},
  3399  			Histograms: []metrics.HistogramSnapshot{
  3400  				{
  3401  					Name:   "stream_duration_ms",
  3402  					Tags:   tags,
  3403  					Unit:   time.Millisecond,
  3404  					Values: []int64{1},
  3405  				},
  3406  				{
  3407  					Name: "stream_request_payload_size_bytes",
  3408  					Tags: tags,
  3409  					Unit: time.Millisecond,
  3410  				},
  3411  				{
  3412  					Name: "stream_response_payload_size_bytes",
  3413  					Tags: tags,
  3414  					Unit: time.Millisecond,
  3415  				},
  3416  			},
  3417  		}
  3418  		assert.Equal(t, want, root.Snapshot(), "unexpected metrics snapshot")
  3419  	})
  3420  }
  3421  
  3422  func TestStreamingMetrics(t *testing.T) {
  3423  	defer stubTime()()
  3424  
  3425  	req := &transport.StreamRequest{
  3426  		Meta: &transport.RequestMeta{
  3427  			Caller:          "caller",
  3428  			Service:         "service",
  3429  			Transport:       "",
  3430  			Encoding:        "raw",
  3431  			Procedure:       "procedure",
  3432  			ShardKey:        "sk",
  3433  			RoutingKey:      "rk",
  3434  			RoutingDelegate: "rd",
  3435  		},
  3436  	}
  3437  
  3438  	newTags := func(direction directionName, withErr string, withCallerFailureErrName string) metrics.Tags {
  3439  		tags := metrics.Tags{
  3440  			"dest":             "service",
  3441  			"direction":        string(direction),
  3442  			"encoding":         "raw",
  3443  			"procedure":        "procedure",
  3444  			"routing_delegate": "rd",
  3445  			"routing_key":      "rk",
  3446  			"rpc_type":         transport.Streaming.String(),
  3447  			"source":           "caller",
  3448  			"transport":        "unknown",
  3449  		}
  3450  		if withErr != "" {
  3451  			tags["error"] = withErr
  3452  		}
  3453  		if withCallerFailureErrName != "" {
  3454  			tags[_errorNameMetricsKey] = withCallerFailureErrName
  3455  		}
  3456  		return tags
  3457  	}
  3458  
  3459  	t.Run("success server", func(t *testing.T) {
  3460  		root := metrics.New()
  3461  		scope := root.Scope()
  3462  		mw := NewMiddleware(Config{
  3463  			Logger:           zap.NewNop(),
  3464  			Scope:            scope,
  3465  			ContextExtractor: NewNopContextExtractor(),
  3466  		})
  3467  
  3468  		stream, err := transport.NewServerStream(
  3469  			&fakeStream{
  3470  				request: req,
  3471  				receiveMsg: &transport.StreamMessage{
  3472  					Body:     readCloser{bytes.NewReader([]byte("Foobar"))},
  3473  					BodySize: 6,
  3474  				},
  3475  			},
  3476  		)
  3477  		require.NoError(t, err)
  3478  		err = mw.HandleStream(stream, &fakeHandler{
  3479  			handleStream: func(stream *transport.ServerStream) {
  3480  				err := stream.SendMessage(
  3481  					context.Background(),
  3482  					&transport.StreamMessage{
  3483  						Body:     readCloser{bytes.NewReader([]byte("test"))},
  3484  						BodySize: 4,
  3485  					},
  3486  				)
  3487  				require.NoError(t, err)
  3488  				_, err = stream.ReceiveMessage(context.Background())
  3489  				require.NoError(t, err)
  3490  			}})
  3491  		require.NoError(t, err)
  3492  
  3493  		snap := root.Snapshot()
  3494  		tags := newTags(_directionInbound, "" /* withErr */, "" /* withCallerFailureErrName */)
  3495  
  3496  		// successful handshake, send, recv and close
  3497  		want := &metrics.RootSnapshot{
  3498  			Counters: []metrics.Snapshot{
  3499  				{Name: "calls", Tags: tags, Value: 1},
  3500  				{Name: "panics", Tags: tags, Value: 0},
  3501  				{Name: "stream_receive_successes", Tags: tags, Value: 1},
  3502  				{Name: "stream_receives", Tags: tags, Value: 1},
  3503  				{Name: "stream_send_successes", Tags: tags, Value: 1},
  3504  				{Name: "stream_sends", Tags: tags, Value: 1},
  3505  				{Name: "successes", Tags: tags, Value: 1},
  3506  			},
  3507  			Gauges: []metrics.Snapshot{
  3508  				{Name: "streams_active", Tags: tags, Value: 0}, // opened (+1) then closed (-1)
  3509  			},
  3510  			Histograms: []metrics.HistogramSnapshot{
  3511  				{Name: "stream_duration_ms", Tags: tags, Unit: time.Millisecond, Values: []int64{1}},
  3512  				{Name: "stream_request_payload_size_bytes", Tags: tags, Unit: time.Millisecond, Values: []int64{8}},
  3513  				{Name: "stream_response_payload_size_bytes", Tags: tags, Unit: time.Millisecond, Values: []int64{4}},
  3514  			},
  3515  		}
  3516  		assert.Equal(t, want, snap, "unexpected metrics snapshot")
  3517  	})
  3518  
  3519  	t.Run("error handler", func(t *testing.T) {
  3520  		tests := []struct {
  3521  			name       string
  3522  			err        error
  3523  			errName    string
  3524  			appErrName string
  3525  		}{
  3526  			{
  3527  				name:       "client fault",
  3528  				err:        yarpcerrors.InvalidArgumentErrorf("client err"),
  3529  				errName:    yarpcerrors.CodeInvalidArgument.String(),
  3530  				appErrName: _notSet,
  3531  			},
  3532  			{
  3533  				name:       "server fault",
  3534  				err:        yarpcerrors.InternalErrorf("server err"),
  3535  				errName:    yarpcerrors.CodeInternal.String(),
  3536  				appErrName: _notSet,
  3537  			},
  3538  			{
  3539  				name:       "unknown fault",
  3540  				err:        errors.New("unknown fault"),
  3541  				errName:    "unknown_internal_yarpc",
  3542  				appErrName: _notSet,
  3543  			},
  3544  		}
  3545  
  3546  		for _, tt := range tests {
  3547  			t.Run(tt.name, func(t *testing.T) {
  3548  				root := metrics.New()
  3549  				scope := root.Scope()
  3550  				mw := NewMiddleware(Config{
  3551  					Logger:           zap.NewNop(),
  3552  					Scope:            scope,
  3553  					ContextExtractor: NewNopContextExtractor(),
  3554  				})
  3555  
  3556  				stream, err := transport.NewServerStream(&fakeStream{request: req})
  3557  				require.NoError(t, err)
  3558  				err = mw.HandleStream(stream, &fakeHandler{err: tt.err})
  3559  				require.Error(t, err)
  3560  
  3561  				snap := root.Snapshot()
  3562  				successTags := newTags(_directionInbound, "", "")
  3563  				errTags := newTags(_directionInbound, tt.errName, tt.appErrName)
  3564  
  3565  				// so we don't have create a sorting implementation, manually place the
  3566  				// first two expected counter snapshots, based on the error fault.
  3567  				counters := make([]metrics.Snapshot, 0, 10)
  3568  				if yarpcerrors.GetFaultTypeFromCode(yarpcerrors.FromError(tt.err).Code()) == yarpcerrors.ClientFault {
  3569  					counters = append(counters,
  3570  						metrics.Snapshot{Name: "caller_failures", Tags: errTags, Value: 1},
  3571  						metrics.Snapshot{Name: "calls", Tags: successTags, Value: 1},
  3572  						metrics.Snapshot{Name: "panics", Tags: successTags, Value: 0})
  3573  
  3574  				} else {
  3575  					counters = append(counters,
  3576  						metrics.Snapshot{Name: "calls", Tags: successTags, Value: 1},
  3577  						metrics.Snapshot{Name: "panics", Tags: successTags, Value: 0},
  3578  						metrics.Snapshot{Name: "server_failures", Tags: errTags, Value: 1})
  3579  				}
  3580  
  3581  				want := &metrics.RootSnapshot{
  3582  					// only the failure vector counters will have an error value passed
  3583  					// into tags()
  3584  					Counters: append(counters,
  3585  						metrics.Snapshot{Name: "stream_receive_successes", Tags: successTags, Value: 0},
  3586  						metrics.Snapshot{Name: "stream_receives", Tags: successTags, Value: 0},
  3587  						metrics.Snapshot{Name: "stream_send_successes", Tags: successTags, Value: 0},
  3588  						metrics.Snapshot{Name: "stream_sends", Tags: successTags, Value: 0},
  3589  						metrics.Snapshot{Name: "successes", Tags: successTags, Value: 1}),
  3590  					Gauges: []metrics.Snapshot{
  3591  						{Name: "streams_active", Tags: successTags, Value: 0},
  3592  					},
  3593  					Histograms: []metrics.HistogramSnapshot{
  3594  						{Name: "stream_duration_ms", Tags: successTags, Unit: time.Millisecond, Values: []int64{1}},
  3595  						{Name: "stream_request_payload_size_bytes", Tags: successTags, Unit: time.Millisecond},
  3596  						{Name: "stream_response_payload_size_bytes", Tags: successTags, Unit: time.Millisecond},
  3597  					},
  3598  				}
  3599  				assert.Equal(t, want, snap, "unexpected metrics snapshot")
  3600  			})
  3601  		}
  3602  	})
  3603  
  3604  	t.Run("error server - send and recv", func(t *testing.T) {
  3605  		root := metrics.New()
  3606  		scope := root.Scope()
  3607  		mw := NewMiddleware(Config{
  3608  			Logger:           zap.NewNop(),
  3609  			Scope:            scope,
  3610  			ContextExtractor: NewNopContextExtractor(),
  3611  		})
  3612  
  3613  		sendErr := errors.New("send err")
  3614  		receiveErr := errors.New("receive err")
  3615  
  3616  		stream, err := transport.NewServerStream(&fakeStream{
  3617  			request:    req,
  3618  			sendErr:    sendErr,
  3619  			receiveErr: receiveErr,
  3620  		})
  3621  		require.NoError(t, err)
  3622  
  3623  		err = mw.HandleStream(stream, &fakeHandler{
  3624  			handleStream: func(stream *transport.ServerStream) {
  3625  				err := stream.SendMessage(context.Background(), nil /*message*/)
  3626  				require.Error(t, err)
  3627  				_, err = stream.ReceiveMessage(context.Background())
  3628  				require.Error(t, err)
  3629  			}})
  3630  		require.NoError(t, err)
  3631  
  3632  		snap := root.Snapshot()
  3633  		successTags := newTags(_directionInbound, "", "")
  3634  		errTags := newTags(_directionInbound, "unknown_internal_yarpc", "")
  3635  
  3636  		want := &metrics.RootSnapshot{
  3637  			Counters: []metrics.Snapshot{
  3638  				{Name: "calls", Tags: successTags, Value: 1},
  3639  				{Name: "panics", Tags: successTags, Value: 0},
  3640  				{Name: "stream_receive_failures", Tags: errTags, Value: 1},
  3641  				{Name: "stream_receive_successes", Tags: successTags, Value: 0},
  3642  				{Name: "stream_receives", Tags: successTags, Value: 1},
  3643  				{Name: "stream_send_failures", Tags: errTags, Value: 1},
  3644  				{Name: "stream_send_successes", Tags: successTags, Value: 0},
  3645  				{Name: "stream_sends", Tags: successTags, Value: 1},
  3646  				{Name: "successes", Tags: successTags, Value: 1},
  3647  			},
  3648  			Gauges: []metrics.Snapshot{
  3649  				{Name: "streams_active", Tags: successTags, Value: 0}, // opened (+1) then closed (-1)
  3650  			},
  3651  			Histograms: []metrics.HistogramSnapshot{
  3652  				{Name: "stream_duration_ms", Tags: successTags, Unit: time.Millisecond, Values: []int64{1}},
  3653  				{Name: "stream_request_payload_size_bytes", Tags: successTags, Unit: time.Millisecond},
  3654  				{Name: "stream_response_payload_size_bytes", Tags: successTags, Unit: time.Millisecond},
  3655  			},
  3656  		}
  3657  		assert.Equal(t, want, snap, "unexpected metrics snapshot")
  3658  	})
  3659  
  3660  	t.Run("success client", func(t *testing.T) {
  3661  		root := metrics.New()
  3662  		scope := root.Scope()
  3663  		mw := NewMiddleware(Config{
  3664  			Logger:           zap.NewNop(),
  3665  			Scope:            scope,
  3666  			ContextExtractor: NewNopContextExtractor(),
  3667  		})
  3668  
  3669  		stream, err := mw.CallStream(context.Background(), req, fakeOutbound{})
  3670  		require.NoError(t, err)
  3671  		err = stream.SendMessage(context.Background(), nil /* message */)
  3672  		require.NoError(t, err)
  3673  		_, err = stream.ReceiveMessage(context.Background())
  3674  		require.NoError(t, err)
  3675  		require.NoError(t, stream.Close(context.Background()))
  3676  
  3677  		snap := root.Snapshot()
  3678  		tags := newTags(_directionOutbound, "", "")
  3679  
  3680  		// successful handshake, send, recv and close
  3681  		want := &metrics.RootSnapshot{
  3682  			Counters: []metrics.Snapshot{
  3683  				{Name: "calls", Tags: tags, Value: 1},
  3684  				{Name: "panics", Tags: tags, Value: 0},
  3685  				{Name: "stream_receive_successes", Tags: tags, Value: 1},
  3686  				{Name: "stream_receives", Tags: tags, Value: 1},
  3687  				{Name: "stream_send_successes", Tags: tags, Value: 1},
  3688  				{Name: "stream_sends", Tags: tags, Value: 1},
  3689  				{Name: "successes", Tags: tags, Value: 1},
  3690  			},
  3691  			Gauges: []metrics.Snapshot{
  3692  				{Name: "streams_active", Tags: tags, Value: 0}, // opened (+1) then closed (-1)
  3693  			},
  3694  			Histograms: []metrics.HistogramSnapshot{
  3695  				{Name: "stream_duration_ms", Tags: tags, Unit: time.Millisecond, Values: []int64{1}},
  3696  				{Name: "stream_request_payload_size_bytes", Tags: tags, Unit: time.Millisecond},
  3697  				{Name: "stream_response_payload_size_bytes", Tags: tags, Unit: time.Millisecond},
  3698  			},
  3699  		}
  3700  		assert.Equal(t, want, snap, "unexpected metrics snapshot")
  3701  	})
  3702  
  3703  	t.Run("error client handshake", func(t *testing.T) {
  3704  		root := metrics.New()
  3705  		scope := root.Scope()
  3706  		mw := NewMiddleware(Config{
  3707  			Logger:           zap.NewNop(),
  3708  			Scope:            scope,
  3709  			ContextExtractor: NewNopContextExtractor(),
  3710  		})
  3711  
  3712  		clientErr := errors.New("client err")
  3713  		_, err := mw.CallStream(context.Background(), req, fakeOutbound{err: clientErr})
  3714  		require.Error(t, err)
  3715  
  3716  		snap := root.Snapshot()
  3717  		successTags := newTags(_directionOutbound, "", "")
  3718  		errTags := newTags(_directionOutbound, "unknown_internal_yarpc", _notSet)
  3719  
  3720  		want := &metrics.RootSnapshot{
  3721  			// only the failure vector counters will have an error value passed
  3722  			// into tags()
  3723  			Counters: []metrics.Snapshot{
  3724  				{Name: "calls", Tags: successTags, Value: 1},
  3725  				{Name: "panics", Tags: successTags, Value: 0},
  3726  				{Name: "server_failures", Tags: errTags, Value: 1},
  3727  				{Name: "stream_receive_successes", Tags: successTags, Value: 0},
  3728  				{Name: "stream_receives", Tags: successTags, Value: 0},
  3729  				{Name: "stream_send_successes", Tags: successTags, Value: 0},
  3730  				{Name: "stream_sends", Tags: successTags, Value: 0},
  3731  				{Name: "successes", Tags: successTags, Value: 0},
  3732  			},
  3733  			Gauges: []metrics.Snapshot{
  3734  				{Name: "streams_active", Tags: successTags, Value: 0},
  3735  			},
  3736  			Histograms: []metrics.HistogramSnapshot{
  3737  				{Name: "stream_duration_ms", Tags: successTags, Unit: time.Millisecond},
  3738  				{Name: "stream_request_payload_size_bytes", Tags: successTags, Unit: time.Millisecond},
  3739  				{Name: "stream_response_payload_size_bytes", Tags: successTags, Unit: time.Millisecond},
  3740  			},
  3741  		}
  3742  		assert.Equal(t, want, snap, "unexpected metrics snapshot")
  3743  	})
  3744  
  3745  	t.Run("error client - send recv close", func(t *testing.T) {
  3746  		root := metrics.New()
  3747  		scope := root.Scope()
  3748  		mw := NewMiddleware(Config{
  3749  			Logger:           zap.NewNop(),
  3750  			Scope:            scope,
  3751  			ContextExtractor: NewNopContextExtractor(),
  3752  		})
  3753  
  3754  		sendErr := errors.New("send err")
  3755  		receiveErr := errors.New("receive err")
  3756  		closeErr := errors.New("close err")
  3757  
  3758  		stream, err := mw.CallStream(context.Background(), req, fakeOutbound{
  3759  			stream: fakeStream{
  3760  				sendErr:    sendErr,
  3761  				receiveErr: receiveErr,
  3762  				closeErr:   closeErr,
  3763  			}})
  3764  		require.NoError(t, err)
  3765  
  3766  		err = stream.SendMessage(context.Background(), nil /* message */)
  3767  		require.Error(t, err)
  3768  		_, err = stream.ReceiveMessage(context.Background())
  3769  		require.Error(t, err)
  3770  		err = stream.Close(context.Background())
  3771  		require.Error(t, err)
  3772  
  3773  		snap := root.Snapshot()
  3774  		successTags := newTags(_directionOutbound, "", "")
  3775  		errTags := newTags(_directionOutbound, "unknown_internal_yarpc", "")
  3776  		serverFailureTags := newTags(_directionOutbound, "unknown_internal_yarpc", _notSet)
  3777  
		// successful handshake, followed by failed send, recv and close
  3779  		want := &metrics.RootSnapshot{
  3780  			Counters: []metrics.Snapshot{
  3781  				{Name: "calls", Tags: successTags, Value: 1},
  3782  				{Name: "panics", Tags: successTags, Value: 0},
  3783  				{Name: "server_failures", Tags: serverFailureTags, Value: 1},
  3784  				{Name: "stream_receive_failures", Tags: errTags, Value: 1},
  3785  				{Name: "stream_receive_successes", Tags: successTags, Value: 0},
  3786  				{Name: "stream_receives", Tags: successTags, Value: 1},
  3787  				{Name: "stream_send_failures", Tags: errTags, Value: 1},
  3788  				{Name: "stream_send_successes", Tags: successTags, Value: 0},
  3789  				{Name: "stream_sends", Tags: successTags, Value: 1},
  3790  				{Name: "successes", Tags: successTags, Value: 1},
  3791  			},
  3792  			Gauges: []metrics.Snapshot{
  3793  				{Name: "streams_active", Tags: successTags, Value: 0}, // opened (+1) then closed (-1)
  3794  			},
  3795  			Histograms: []metrics.HistogramSnapshot{
  3796  				{Name: "stream_duration_ms", Tags: successTags, Unit: time.Millisecond, Values: []int64{1}},
  3797  				{Name: "stream_request_payload_size_bytes", Tags: successTags, Unit: time.Millisecond},
  3798  				{Name: "stream_response_payload_size_bytes", Tags: successTags, Unit: time.Millisecond},
  3799  			},
  3800  		}
  3801  		assert.Equal(t, want, snap, "unexpected metrics snapshot")
  3802  	})
  3803  }
  3804  
  3805  func TestNewWriterIsEmpty(t *testing.T) {
  3806  	code := yarpcerrors.CodeDataLoss
  3807  
  3808  	// set all fields on the response writer
  3809  	w := newWriter(&transporttest.FakeResponseWriter{})
  3810  	require.NotNil(t, w, "writer must not be nil")
  3811  
  3812  	w.SetApplicationError()
  3813  	w.SetApplicationErrorMeta(&transport.ApplicationErrorMeta{
  3814  		Details: "foo", Name: "bar", Code: &code,
  3815  	})
  3816  	w.free()
  3817  
  3818  	w = newWriter(nil /*transport.ResponseWriter*/)
  3819  	require.NotNil(t, w, "writer must not be nil")
  3820  	assert.Equal(t, writer{}, *w,
  3821  		"expected empty writer, fields were likely not cleared in the pool")
  3822  }
  3823  
  3824  type readCloser struct {
  3825  	*bytes.Reader
  3826  }
  3827  
  3828  func (r readCloser) Close() error {
  3829  	return nil
  3830  }
  3831  
  3832  func BenchmarkMiddlewareHandle(b *testing.B) {
  3833  	m := NewMiddleware(Config{
  3834  		Logger: zap.NewNop(),
  3835  	})
  3836  
  3837  	req := &transport.Request{
  3838  		Caller:          "caller",
  3839  		Service:         "service",
  3840  		Transport:       "",
  3841  		Encoding:        "raw",
  3842  		Procedure:       "procedure",
  3843  		Headers:         transport.NewHeaders().With("password", "super-secret"),
  3844  		ShardKey:        "shard01",
  3845  		RoutingKey:      "routing-key",
  3846  		RoutingDelegate: "routing-delegate",
  3847  		Body:            strings.NewReader("body"),
  3848  	}
  3849  
  3850  	b.ResetTimer()
  3851  
  3852  	for i := 0; i < b.N; i++ {
  3853  		err := m.Handle(context.Background(), req, &transporttest.FakeResponseWriter{}, &fakeHandler{})
  3854  		assert.NoError(b, err)
  3855  	}
  3856  }