// The MIT License
//
// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved.
//
// Copyright (c) 2020 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
24 25 package headers 26 27 import ( 28 "context" 29 "testing" 30 31 "github.com/stretchr/testify/require" 32 "github.com/stretchr/testify/suite" 33 "google.golang.org/grpc/metadata" 34 ) 35 36 type ( 37 callerInfoSuite struct { 38 *require.Assertions 39 suite.Suite 40 } 41 ) 42 43 func TestCallerInfoSuite(t *testing.T) { 44 suite.Run(t, &callerInfoSuite{}) 45 } 46 47 func (s *callerInfoSuite) SetupTest() { 48 s.Assertions = require.New(s.T()) 49 } 50 51 func (s *callerInfoSuite) TestSetCallerName() { 52 ctx := context.Background() 53 info := GetCallerInfo(ctx) 54 s.Empty(info.CallerName) 55 56 ctx = SetCallerName(ctx, CallerNameSystem) 57 info = GetCallerInfo(ctx) 58 s.Equal(CallerNameSystem, info.CallerName) 59 60 ctx = SetCallerName(ctx, "") 61 info = GetCallerInfo(ctx) 62 s.Equal(CallerNameSystem, info.CallerName) 63 64 newCallerName := "new caller name" 65 ctx = SetCallerName(ctx, newCallerName) 66 info = GetCallerInfo(ctx) 67 s.Equal(newCallerName, info.CallerName) 68 } 69 70 func (s *callerInfoSuite) TestSetCallerType() { 71 ctx := context.Background() 72 info := GetCallerInfo(ctx) 73 s.Empty(info.CallerType) 74 75 ctx = SetCallerType(ctx, CallerTypeBackground) 76 info = GetCallerInfo(ctx) 77 s.Equal(CallerTypeBackground, info.CallerType) 78 79 ctx = SetCallerName(ctx, "") 80 info = GetCallerInfo(ctx) 81 s.Equal(CallerTypeBackground, info.CallerType) 82 83 ctx = SetCallerType(ctx, CallerTypeAPI) 84 info = GetCallerInfo(ctx) 85 s.Equal(CallerTypeAPI, info.CallerType) 86 87 ctx = SetCallerType(ctx, CallerTypePreemptable) 88 info = GetCallerInfo(ctx) 89 s.Equal(CallerTypePreemptable, info.CallerType) 90 } 91 92 func (s *callerInfoSuite) TestSetCallOrigin() { 93 ctx := context.Background() 94 info := GetCallerInfo(ctx) 95 s.Empty(info.CallOrigin) 96 97 initiation := "method name" 98 ctx = SetOrigin(ctx, initiation) 99 info = GetCallerInfo(ctx) 100 s.Equal(initiation, info.CallOrigin) 101 102 ctx = SetOrigin(ctx, "") 103 info = GetCallerInfo(ctx) 104 
s.Equal(initiation, info.CallOrigin) 105 106 newCallOrigin := "another method name" 107 ctx = SetOrigin(ctx, newCallOrigin) 108 info = GetCallerInfo(ctx) 109 s.Equal(newCallOrigin, info.CallOrigin) 110 } 111 112 func (s *callerInfoSuite) TestSetCallerInfo_PreserveOtherValues() { 113 existingKey := "key" 114 existingValue := "value" 115 callerName := "callerName" 116 callerType := CallerTypeAPI 117 callOrigin := "methodName" 118 119 ctx := metadata.NewIncomingContext( 120 context.Background(), 121 metadata.Pairs(existingKey, existingValue), 122 ) 123 124 ctx = SetCallerInfo(ctx, NewCallerInfo(callerName, callerType, callOrigin)) 125 126 md, ok := metadata.FromIncomingContext(ctx) 127 s.True(ok) 128 s.Equal(existingValue, md.Get(existingKey)[0]) 129 s.Equal(callerName, md.Get(callerNameHeaderName)[0]) 130 s.Equal(callerType, md.Get(CallerTypeHeaderName)[0]) 131 s.Equal(callOrigin, md.Get(callOriginHeaderName)[0]) 132 s.Len(md, 4) 133 } 134 135 func (s *callerInfoSuite) TestSetCallerInfo_NoExistingCallerInfo() { 136 callerName := CallerNameSystem 137 callerType := CallerTypeAPI 138 callOrigin := "methodName" 139 140 ctx := SetCallerInfo(context.Background(), CallerInfo{ 141 CallerName: callerName, 142 CallerType: callerType, 143 CallOrigin: callOrigin, 144 }) 145 146 md, ok := metadata.FromIncomingContext(ctx) 147 s.True(ok) 148 s.Equal(callerName, md.Get(callerNameHeaderName)[0]) 149 s.Equal(callerType, md.Get(CallerTypeHeaderName)[0]) 150 s.Equal(callOrigin, md.Get(callOriginHeaderName)[0]) 151 s.Len(md, 3) 152 } 153 154 func (s *callerInfoSuite) TestSetCallerInfo_WithExistingCallerInfo() { 155 callerName := CallerNameSystem 156 callerType := CallerTypeBackground 157 callOrigin := "methodName" 158 159 ctx := SetCallerName(context.Background(), callerName) 160 ctx = SetCallerType(ctx, CallerTypeAPI) 161 ctx = SetOrigin(ctx, callOrigin) 162 163 ctx = SetCallerInfo(ctx, CallerInfo{ 164 CallerName: "", 165 CallerType: callerType, 166 CallOrigin: "", 167 }) 168 169 md, 
ok := metadata.FromIncomingContext(ctx) 170 s.True(ok) 171 s.Equal(callerName, md.Get(callerNameHeaderName)[0]) 172 s.Equal(callerType, md.Get(CallerTypeHeaderName)[0]) 173 s.Equal(callOrigin, md.Get(callOriginHeaderName)[0]) 174 s.Len(md, 3) 175 } 176 177 func (s *callerInfoSuite) TestSetCallerInfo_WithPartialCallerInfo() { 178 callerName := CallerNameSystem 179 callerType := CallerTypeAPI 180 181 ctx := SetCallerType(context.Background(), callerType) 182 183 ctx = SetCallerInfo(ctx, CallerInfo{ 184 CallerName: callerName, 185 }) 186 187 md, ok := metadata.FromIncomingContext(ctx) 188 s.True(ok) 189 s.Equal(callerName, md.Get(callerNameHeaderName)[0]) 190 s.Equal(callerType, md.Get(CallerTypeHeaderName)[0]) 191 s.Empty(md.Get(callOriginHeaderName)) 192 s.Len(md, 2) 193 }