go.temporal.io/server@v1.23.0/common/rpc/interceptor/caller_info_test.go (about)

     1  // The MIT License
     2  //
     3  // Copyright (c) 2020 Temporal Technologies Inc.  All rights reserved.
     4  //
     5  // Copyright (c) 2020 Uber Technologies, Inc.
     6  //
     7  // Permission is hereby granted, free of charge, to any person obtaining a copy
     8  // of this software and associated documentation files (the "Software"), to deal
     9  // in the Software without restriction, including without limitation the rights
    10  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    11  // copies of the Software, and to permit persons to whom the Software is
    12  // furnished to do so, subject to the following conditions:
    13  //
    14  // The above copyright notice and this permission notice shall be included in
    15  // all copies or substantial portions of the Software.
    16  //
    17  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    18  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    19  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    20  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    21  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    22  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    23  // THE SOFTWARE.
    24  
    25  package interceptor
    26  
    27  import (
    28  	"context"
    29  	"testing"
    30  
    31  	"github.com/golang/mock/gomock"
    32  	"github.com/stretchr/testify/require"
    33  	"github.com/stretchr/testify/suite"
    34  	"go.temporal.io/api/workflowservice/v1"
    35  	"google.golang.org/grpc"
    36  
    37  	"go.temporal.io/server/common/headers"
    38  	"go.temporal.io/server/common/namespace"
    39  )
    40  
    41  type (
    42  	callerInfoSuite struct {
    43  		suite.Suite
    44  		*require.Assertions
    45  
    46  		controller   *gomock.Controller
    47  		mockRegistry *namespace.MockRegistry
    48  
    49  		interceptor *CallerInfoInterceptor
    50  	}
    51  )
    52  
    53  func TestCallerInfoSuite(t *testing.T) {
    54  	s := new(callerInfoSuite)
    55  	suite.Run(t, s)
    56  }
    57  
    58  func (s *callerInfoSuite) SetupTest() {
    59  	s.Assertions = require.New(s.T())
    60  
    61  	s.controller = gomock.NewController(s.T())
    62  	s.mockRegistry = namespace.NewMockRegistry(s.controller)
    63  
    64  	s.interceptor = NewCallerInfoInterceptor(s.mockRegistry)
    65  }
    66  
    67  func (s *callerInfoSuite) TearDownSuite() {
    68  	s.controller.Finish()
    69  }
    70  
    71  func (s *callerInfoSuite) TestIntercept_CallerName() {
    72  	// testNamespaceID := namespace.NewID()
    73  	testNamespaceName := namespace.Name("test-namespace")
    74  	// s.mockRegistry.EXPECT().GetNamespaceName(testNamespaceID).Return(testNamespaceName, nil).AnyTimes()
    75  	s.mockRegistry.EXPECT().GetNamespace(gomock.Any()).Return(nil, nil).AnyTimes()
    76  
    77  	testCases := []struct {
    78  		setupIncomingCtx   func() context.Context
    79  		request            interface{}
    80  		expectedCallerName string
    81  	}{
    82  		{
    83  			// test context with no caller info
    84  			setupIncomingCtx: func() context.Context {
    85  				return context.Background()
    86  			},
    87  			request: &workflowservice.StartWorkflowExecutionRequest{
    88  				Namespace: testNamespaceName.String(),
    89  			},
    90  			expectedCallerName: testNamespaceName.String(),
    91  		},
    92  		{
    93  			// test context with caller type but no caller name
    94  			setupIncomingCtx: func() context.Context {
    95  				return headers.SetCallerType(context.Background(), headers.CallerTypeBackground)
    96  			},
    97  			request: &workflowservice.StartWorkflowExecutionRequest{
    98  				Namespace: testNamespaceName.String(),
    99  			},
   100  			expectedCallerName: testNamespaceName.String(),
   101  		},
   102  		{
   103  			// test context with caller name
   104  			setupIncomingCtx: func() context.Context {
   105  				return headers.SetCallerName(context.Background(), headers.CallerNameSystem)
   106  			},
   107  			request: &workflowservice.StartWorkflowExecutionRequest{
   108  				Namespace: testNamespaceName.String(),
   109  			},
   110  			expectedCallerName: headers.CallerNameSystem,
   111  		},
   112  		{
   113  			// test context with empty caller name
   114  			setupIncomingCtx: func() context.Context {
   115  				return headers.SetCallerName(context.Background(), "")
   116  			},
   117  			request: &workflowservice.StartWorkflowExecutionRequest{
   118  				Namespace: testNamespaceName.String(),
   119  			},
   120  			expectedCallerName: testNamespaceName.String(),
   121  		},
   122  	}
   123  
   124  	for _, testCase := range testCases {
   125  		ctx := testCase.setupIncomingCtx()
   126  
   127  		var resultingCtx context.Context
   128  		_, err := s.interceptor.Intercept(
   129  			ctx,
   130  			testCase.request,
   131  			&grpc.UnaryServerInfo{},
   132  			func(ctx context.Context, req interface{}) (interface{}, error) {
   133  				resultingCtx = ctx
   134  				return nil, nil
   135  			},
   136  		)
   137  		s.NoError(err)
   138  
   139  		actualCallerName := headers.GetCallerInfo(resultingCtx).CallerName
   140  		s.Equal(testCase.expectedCallerName, actualCallerName)
   141  	}
   142  }
   143  
   144  func (s *callerInfoSuite) TestIntercept_CallerType() {
   145  	s.mockRegistry.EXPECT().GetNamespace(gomock.Any()).Return(nil, nil).AnyTimes()
   146  
   147  	testCases := []struct {
   148  		setupIncomingCtx   func() context.Context
   149  		request            interface{}
   150  		expectedCallerType string
   151  	}{
   152  		{
   153  			// test context with no caller info
   154  			setupIncomingCtx: func() context.Context {
   155  				return context.Background()
   156  			},
   157  			request:            &workflowservice.StartWorkflowExecutionRequest{},
   158  			expectedCallerType: headers.CallerTypeAPI,
   159  		},
   160  		{
   161  			// test context with caller name but no caller type
   162  			setupIncomingCtx: func() context.Context {
   163  				return headers.SetCallerName(context.Background(), "test-namespace")
   164  			},
   165  			request:            &workflowservice.StartWorkflowExecutionRequest{},
   166  			expectedCallerType: headers.CallerTypeAPI,
   167  		},
   168  		{
   169  			// test context with caller type
   170  			setupIncomingCtx: func() context.Context {
   171  				return headers.SetCallerType(context.Background(), headers.CallerTypeBackground)
   172  			},
   173  			request:            &workflowservice.StartWorkflowExecutionRequest{},
   174  			expectedCallerType: headers.CallerTypeBackground,
   175  		},
   176  		{
   177  			// test context with empty caller type
   178  			setupIncomingCtx: func() context.Context {
   179  				return headers.SetCallerType(context.Background(), "")
   180  			},
   181  			request:            &workflowservice.StartWorkflowExecutionRequest{},
   182  			expectedCallerType: headers.CallerTypeAPI,
   183  		},
   184  	}
   185  
   186  	for _, testCase := range testCases {
   187  		ctx := testCase.setupIncomingCtx()
   188  
   189  		var resultingCtx context.Context
   190  		_, err := s.interceptor.Intercept(
   191  			ctx,
   192  			testCase.request,
   193  			&grpc.UnaryServerInfo{},
   194  			func(ctx context.Context, req interface{}) (interface{}, error) {
   195  				resultingCtx = ctx
   196  				return nil, nil
   197  			},
   198  		)
   199  		s.NoError(err)
   200  
   201  		actualCallerType := headers.GetCallerInfo(resultingCtx).CallerType
   202  		s.Equal(testCase.expectedCallerType, actualCallerType)
   203  	}
   204  }
   205  
   206  func (s *callerInfoSuite) TestIntercept_CallOrigin() {
   207  	method := "startWorkflowExecutionRequest"
   208  	serverInfo := &grpc.UnaryServerInfo{
   209  		FullMethod: "/temporal/" + method,
   210  	}
   211  	s.mockRegistry.EXPECT().GetNamespace(gomock.Any()).Return(nil, nil).AnyTimes()
   212  
   213  	testCases := []struct {
   214  		setupIncomingCtx   func() context.Context
   215  		request            interface{}
   216  		expectedCallOrigin string
   217  	}{
   218  		{
   219  			// test context with no caller info
   220  			setupIncomingCtx: func() context.Context {
   221  				return context.Background()
   222  			},
   223  			request:            &workflowservice.StartWorkflowExecutionRequest{},
   224  			expectedCallOrigin: method,
   225  		},
   226  		{
   227  			// test context with api caller type but no call initiation
   228  			setupIncomingCtx: func() context.Context {
   229  				return headers.SetCallerName(context.Background(), "test-namespace")
   230  			},
   231  			request:            &workflowservice.StartWorkflowExecutionRequest{},
   232  			expectedCallOrigin: method,
   233  		},
   234  		{
   235  			// test context with background caller type but no call initiation
   236  			setupIncomingCtx: func() context.Context {
   237  				return headers.SetCallerInfo(context.Background(), headers.SystemBackgroundCallerInfo)
   238  			},
   239  			request:            &workflowservice.StartWorkflowExecutionRequest{},
   240  			expectedCallOrigin: "",
   241  		},
   242  		{
   243  			// test context with call initiation
   244  			setupIncomingCtx: func() context.Context {
   245  				return headers.SetOrigin(context.Background(), "test-method")
   246  			},
   247  			request:            &workflowservice.StartWorkflowExecutionRequest{},
   248  			expectedCallOrigin: "test-method",
   249  		},
   250  		{
   251  			// test context with empty call initiation
   252  			setupIncomingCtx: func() context.Context {
   253  				return headers.SetOrigin(context.Background(), "")
   254  			},
   255  			request:            &workflowservice.StartWorkflowExecutionRequest{},
   256  			expectedCallOrigin: method,
   257  		},
   258  	}
   259  
   260  	for _, testCase := range testCases {
   261  		ctx := testCase.setupIncomingCtx()
   262  
   263  		var resultingCtx context.Context
   264  		_, err := s.interceptor.Intercept(
   265  			ctx,
   266  			testCase.request,
   267  			serverInfo,
   268  			func(ctx context.Context, req interface{}) (interface{}, error) {
   269  				resultingCtx = ctx
   270  				return nil, nil
   271  			},
   272  		)
   273  		s.NoError(err)
   274  
   275  		actualCallOrigin := headers.GetCallerInfo(resultingCtx).CallOrigin
   276  		s.Equal(testCase.expectedCallOrigin, actualCallOrigin)
   277  	}
   278  }