go.temporal.io/server@v1.23.0/common/headers/caller_info_test.go (about)

     1  // The MIT License
     2  //
     3  // Copyright (c) 2020 Temporal Technologies Inc.  All rights reserved.
     4  //
     5  // Copyright (c) 2020 Uber Technologies, Inc.
     6  //
     7  // Permission is hereby granted, free of charge, to any person obtaining a copy
     8  // of this software and associated documentation files (the "Software"), to deal
     9  // in the Software without restriction, including without limitation the rights
    10  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    11  // copies of the Software, and to permit persons to whom the Software is
    12  // furnished to do so, subject to the following conditions:
    13  //
    14  // The above copyright notice and this permission notice shall be included in
    15  // all copies or substantial portions of the Software.
    16  //
    17  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    18  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    19  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    20  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    21  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    22  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    23  // THE SOFTWARE.
    24  
    25  package headers
    26  
    27  import (
    28  	"context"
    29  	"testing"
    30  
    31  	"github.com/stretchr/testify/require"
    32  	"github.com/stretchr/testify/suite"
    33  	"google.golang.org/grpc/metadata"
    34  )
    35  
    36  type (
    37  	callerInfoSuite struct {
    38  		*require.Assertions
    39  		suite.Suite
    40  	}
    41  )
    42  
    43  func TestCallerInfoSuite(t *testing.T) {
    44  	suite.Run(t, &callerInfoSuite{})
    45  }
    46  
    47  func (s *callerInfoSuite) SetupTest() {
    48  	s.Assertions = require.New(s.T())
    49  }
    50  
    51  func (s *callerInfoSuite) TestSetCallerName() {
    52  	ctx := context.Background()
    53  	info := GetCallerInfo(ctx)
    54  	s.Empty(info.CallerName)
    55  
    56  	ctx = SetCallerName(ctx, CallerNameSystem)
    57  	info = GetCallerInfo(ctx)
    58  	s.Equal(CallerNameSystem, info.CallerName)
    59  
    60  	ctx = SetCallerName(ctx, "")
    61  	info = GetCallerInfo(ctx)
    62  	s.Equal(CallerNameSystem, info.CallerName)
    63  
    64  	newCallerName := "new caller name"
    65  	ctx = SetCallerName(ctx, newCallerName)
    66  	info = GetCallerInfo(ctx)
    67  	s.Equal(newCallerName, info.CallerName)
    68  }
    69  
    70  func (s *callerInfoSuite) TestSetCallerType() {
    71  	ctx := context.Background()
    72  	info := GetCallerInfo(ctx)
    73  	s.Empty(info.CallerType)
    74  
    75  	ctx = SetCallerType(ctx, CallerTypeBackground)
    76  	info = GetCallerInfo(ctx)
    77  	s.Equal(CallerTypeBackground, info.CallerType)
    78  
    79  	ctx = SetCallerName(ctx, "")
    80  	info = GetCallerInfo(ctx)
    81  	s.Equal(CallerTypeBackground, info.CallerType)
    82  
    83  	ctx = SetCallerType(ctx, CallerTypeAPI)
    84  	info = GetCallerInfo(ctx)
    85  	s.Equal(CallerTypeAPI, info.CallerType)
    86  
    87  	ctx = SetCallerType(ctx, CallerTypePreemptable)
    88  	info = GetCallerInfo(ctx)
    89  	s.Equal(CallerTypePreemptable, info.CallerType)
    90  }
    91  
    92  func (s *callerInfoSuite) TestSetCallOrigin() {
    93  	ctx := context.Background()
    94  	info := GetCallerInfo(ctx)
    95  	s.Empty(info.CallOrigin)
    96  
    97  	initiation := "method name"
    98  	ctx = SetOrigin(ctx, initiation)
    99  	info = GetCallerInfo(ctx)
   100  	s.Equal(initiation, info.CallOrigin)
   101  
   102  	ctx = SetOrigin(ctx, "")
   103  	info = GetCallerInfo(ctx)
   104  	s.Equal(initiation, info.CallOrigin)
   105  
   106  	newCallOrigin := "another method name"
   107  	ctx = SetOrigin(ctx, newCallOrigin)
   108  	info = GetCallerInfo(ctx)
   109  	s.Equal(newCallOrigin, info.CallOrigin)
   110  }
   111  
   112  func (s *callerInfoSuite) TestSetCallerInfo_PreserveOtherValues() {
   113  	existingKey := "key"
   114  	existingValue := "value"
   115  	callerName := "callerName"
   116  	callerType := CallerTypeAPI
   117  	callOrigin := "methodName"
   118  
   119  	ctx := metadata.NewIncomingContext(
   120  		context.Background(),
   121  		metadata.Pairs(existingKey, existingValue),
   122  	)
   123  
   124  	ctx = SetCallerInfo(ctx, NewCallerInfo(callerName, callerType, callOrigin))
   125  
   126  	md, ok := metadata.FromIncomingContext(ctx)
   127  	s.True(ok)
   128  	s.Equal(existingValue, md.Get(existingKey)[0])
   129  	s.Equal(callerName, md.Get(callerNameHeaderName)[0])
   130  	s.Equal(callerType, md.Get(CallerTypeHeaderName)[0])
   131  	s.Equal(callOrigin, md.Get(callOriginHeaderName)[0])
   132  	s.Len(md, 4)
   133  }
   134  
   135  func (s *callerInfoSuite) TestSetCallerInfo_NoExistingCallerInfo() {
   136  	callerName := CallerNameSystem
   137  	callerType := CallerTypeAPI
   138  	callOrigin := "methodName"
   139  
   140  	ctx := SetCallerInfo(context.Background(), CallerInfo{
   141  		CallerName: callerName,
   142  		CallerType: callerType,
   143  		CallOrigin: callOrigin,
   144  	})
   145  
   146  	md, ok := metadata.FromIncomingContext(ctx)
   147  	s.True(ok)
   148  	s.Equal(callerName, md.Get(callerNameHeaderName)[0])
   149  	s.Equal(callerType, md.Get(CallerTypeHeaderName)[0])
   150  	s.Equal(callOrigin, md.Get(callOriginHeaderName)[0])
   151  	s.Len(md, 3)
   152  }
   153  
   154  func (s *callerInfoSuite) TestSetCallerInfo_WithExistingCallerInfo() {
   155  	callerName := CallerNameSystem
   156  	callerType := CallerTypeBackground
   157  	callOrigin := "methodName"
   158  
   159  	ctx := SetCallerName(context.Background(), callerName)
   160  	ctx = SetCallerType(ctx, CallerTypeAPI)
   161  	ctx = SetOrigin(ctx, callOrigin)
   162  
   163  	ctx = SetCallerInfo(ctx, CallerInfo{
   164  		CallerName: "",
   165  		CallerType: callerType,
   166  		CallOrigin: "",
   167  	})
   168  
   169  	md, ok := metadata.FromIncomingContext(ctx)
   170  	s.True(ok)
   171  	s.Equal(callerName, md.Get(callerNameHeaderName)[0])
   172  	s.Equal(callerType, md.Get(CallerTypeHeaderName)[0])
   173  	s.Equal(callOrigin, md.Get(callOriginHeaderName)[0])
   174  	s.Len(md, 3)
   175  }
   176  
   177  func (s *callerInfoSuite) TestSetCallerInfo_WithPartialCallerInfo() {
   178  	callerName := CallerNameSystem
   179  	callerType := CallerTypeAPI
   180  
   181  	ctx := SetCallerType(context.Background(), callerType)
   182  
   183  	ctx = SetCallerInfo(ctx, CallerInfo{
   184  		CallerName: callerName,
   185  	})
   186  
   187  	md, ok := metadata.FromIncomingContext(ctx)
   188  	s.True(ok)
   189  	s.Equal(callerName, md.Get(callerNameHeaderName)[0])
   190  	s.Equal(callerType, md.Get(CallerTypeHeaderName)[0])
   191  	s.Empty(md.Get(callOriginHeaderName))
   192  	s.Len(md, 2)
   193  }