github.com/shuguocloud/go-zero@v1.3.0/zrpc/internal/serverinterceptors/statinterceptor_test.go (about)

     1  package serverinterceptors
     2  
     3  import (
     4  	"context"
     5  	"net"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/stretchr/testify/assert"
    10  	"github.com/shuguocloud/go-zero/core/lang"
    11  	"github.com/shuguocloud/go-zero/core/stat"
    12  	"google.golang.org/grpc"
    13  	"google.golang.org/grpc/peer"
    14  )
    15  
    16  func TestSetSlowThreshold(t *testing.T) {
    17  	assert.Equal(t, defaultSlowThreshold, slowThreshold.Load())
    18  	SetSlowThreshold(time.Second)
    19  	assert.Equal(t, time.Second, slowThreshold.Load())
    20  }
    21  
    22  func TestUnaryStatInterceptor(t *testing.T) {
    23  	metrics := stat.NewMetrics("mock")
    24  	interceptor := UnaryStatInterceptor(metrics)
    25  	_, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{
    26  		FullMethod: "/",
    27  	}, func(ctx context.Context, req interface{}) (interface{}, error) {
    28  		return nil, nil
    29  	})
    30  	assert.Nil(t, err)
    31  }
    32  
    33  func TestUnaryStatInterceptor_crash(t *testing.T) {
    34  	metrics := stat.NewMetrics("mock")
    35  	interceptor := UnaryStatInterceptor(metrics)
    36  	_, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{
    37  		FullMethod: "/",
    38  	}, func(ctx context.Context, req interface{}) (interface{}, error) {
    39  		panic("error")
    40  	})
    41  	assert.NotNil(t, err)
    42  }
    43  
    44  func TestLogDuration(t *testing.T) {
    45  	addrs, err := net.InterfaceAddrs()
    46  	assert.Nil(t, err)
    47  	assert.True(t, len(addrs) > 0)
    48  
    49  	tests := []struct {
    50  		name     string
    51  		ctx      context.Context
    52  		req      interface{}
    53  		duration time.Duration
    54  	}{
    55  		{
    56  			name: "normal",
    57  			ctx:  context.Background(),
    58  			req:  "foo",
    59  		},
    60  		{
    61  			name: "bad req",
    62  			ctx:  context.Background(),
    63  			req:  make(chan lang.PlaceholderType), // not marshalable
    64  		},
    65  		{
    66  			name:     "timeout",
    67  			ctx:      context.Background(),
    68  			req:      "foo",
    69  			duration: time.Second,
    70  		},
    71  		{
    72  			name: "timeout",
    73  			ctx: peer.NewContext(context.Background(), &peer.Peer{
    74  				Addr: addrs[0],
    75  			}),
    76  			req: "foo",
    77  		},
    78  	}
    79  
    80  	for _, test := range tests {
    81  		test := test
    82  		t.Run(test.name, func(t *testing.T) {
    83  			t.Parallel()
    84  
    85  			assert.NotPanics(t, func() {
    86  				logDuration(test.ctx, "foo", test.req, test.duration)
    87  			})
    88  		})
    89  	}
    90  }