github.com/cloudwego/kitex@v0.9.0/pkg/utils/kitexutil/kitexutil_test.go (about) 1 /* 2 * Copyright 2023 CloudWeGo Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package kitexutil 18 19 import ( 20 "context" 21 "net" 22 "os" 23 "reflect" 24 "testing" 25 26 mocks "github.com/cloudwego/kitex/internal/mocks/thrift/fast" 27 "github.com/cloudwego/kitex/internal/test" 28 "github.com/cloudwego/kitex/pkg/rpcinfo" 29 "github.com/cloudwego/kitex/pkg/utils" 30 "github.com/cloudwego/kitex/transport" 31 ) 32 33 var ( 34 testRi rpcinfo.RPCInfo 35 testCtx context.Context 36 panicCtx context.Context 37 caller = "kitexutil.from.service" 38 callee = "kitexutil.to.service" 39 idlServiceName = "MockService" 40 fromAddr = utils.NewNetAddr("test", "127.0.0.1:12345") 41 fromMethod = "from_method" 42 method = "method" 43 tp = transport.TTHeader 44 ) 45 46 func TestMain(m *testing.M) { 47 testRi = buildRPCInfo() 48 49 testCtx = context.Background() 50 testCtx = rpcinfo.NewCtxWithRPCInfo(testCtx, testRi) 51 panicCtx = rpcinfo.NewCtxWithRPCInfo(context.Background(), &panicRPCInfo{}) 52 53 os.Exit(m.Run()) 54 } 55 56 func TestGetCaller(t *testing.T) { 57 type args struct { 58 ctx context.Context 59 } 60 tests := []struct { 61 name string 62 args args 63 want string 64 want1 bool 65 }{ 66 {name: "Success", args: args{testCtx}, want: caller, want1: true}, 67 {name: "Failure", args: args{context.Background()}, want: "", want1: false}, 68 {name: "Panic recovered", args: args{panicCtx}, want: "", want1: false}, 69 } 70 for _, tt := range tests { 71 t.Run(tt.name, func(t *testing.T) { 72 got, got1 := GetCaller(tt.args.ctx) 73 if got != tt.want { 74 t.Errorf("GetCaller() got = %v, want %v", got, tt.want) 75 } 76 if got1 != tt.want1 { 77 t.Errorf("GetCaller() got1 = %v, want %v", got1, tt.want1) 78 } 79 }) 80 } 81 } 82 83 func TestGetCallerAddr(t *testing.T) { 84 type args struct { 85 ctx context.Context 86 } 87 tests := []struct { 88 name string 89 args args 90 want net.Addr 91 want1 bool 92 }{ 93 {name: "Success", args: args{testCtx}, want: fromAddr, want1: true}, 94 {name: "Failure", args: args{context.Background()}, want: nil, want1: false}, 95 {name: "Panic recovered", args: args{panicCtx}, want: nil, want1: false}, 96 } 97 for _, tt := range tests { 98 t.Run(tt.name, func(t *testing.T) { 99 got, got1 := GetCallerAddr(tt.args.ctx) 100 if !reflect.DeepEqual(got, tt.want) { 101 t.Errorf("GetCallerAddr() got = %v, want %v", got, tt.want) 102 } 103 if got1 != tt.want1 { 104 t.Errorf("GetCallerAddr() got1 = %v, want %v", got1, tt.want1) 105 } 106 }) 107 } 108 } 109 110 func TestGetCallerIP(t *testing.T) { 111 ip, ok := GetCallerIP(testCtx) 112 test.Assert(t, ok) 113 test.Assert(t, ip == "127.0.0.1", ip) 114 115 ri := buildRPCInfo() 116 rpcinfo.AsMutableEndpointInfo(ri.From()).SetAddress(utils.NewNetAddr("test", "127.0.0.1")) 117 ip, ok = GetCallerIP(rpcinfo.NewCtxWithRPCInfo(context.Background(), ri)) 118 test.Assert(t, ok) 119 test.Assert(t, ip == "127.0.0.1", ip) 120 121 ip, ok = GetCallerIP(context.Background()) 122 test.Assert(t, !ok) 123 test.Assert(t, ip == "", ip) 124 125 ip, ok = GetCallerIP(panicCtx) 126 test.Assert(t, !ok) 127 test.Assert(t, ip == "", ip) 128 129 rpcinfo.AsMutableEndpointInfo(ri.From()).SetAddress(utils.NewNetAddr("test", "")) 130 ip, ok = GetCallerIP(rpcinfo.NewCtxWithRPCInfo(context.Background(), ri)) 131 test.Assert(t, !ok) 132 test.Assert(t, ip == "", ip) 133 } 134 135 func TestGetMethod(t *testing.T) { 136 type args struct { 137 ctx context.Context 138 } 139 tests := []struct { 140 name string 141 args args 142 want string 143 want1 bool 144 }{ 145 {name: "Success", args: args{testCtx}, want: method, want1: true}, 146 {name: "Failure", args: args{context.Background()}, want: "", want1: false}, 147 } 148 for _, tt := range tests { 149 t.Run(tt.name, func(t *testing.T) { 150 got, got1 := GetMethod(tt.args.ctx) 151 if got != tt.want { 152 t.Errorf("GetMethod() got = %v, want %v", got, tt.want) 153 } 154 if got1 != tt.want1 { 155 t.Errorf("GetMethod() got1 = %v, want %v", got1, tt.want1) 156 } 157 }) 158 } 159 } 160 161 func TestGetCallerHandlerMethod(t *testing.T) { 162 type args struct { 163 ctx context.Context 164 } 165 tests := []struct { 166 name string 167 args args 168 want string 169 want1 bool 170 }{ 171 {name: "Success", args: args{testCtx}, want: fromMethod, want1: true}, 172 {name: "Failure", args: args{context.Background()}, want: "", want1: false}, 173 {name: "Panic recovered", args: args{panicCtx}, want: "", want1: false}, 174 } 175 for _, tt := range tests { 176 t.Run(tt.name, func(t *testing.T) { 177 got, got1 := GetCallerHandlerMethod(tt.args.ctx) 178 if !reflect.DeepEqual(got, tt.want) { 179 t.Errorf("GetCallerHandlerMethod() got = %v, want %v", got, tt.want) 180 } 181 if got1 != tt.want1 { 182 t.Errorf("GetCallerHandlerMethod() got1 = %v, want %v", got1, tt.want1) 183 } 184 }) 185 } 186 } 187 188 func TestGetIDLServiceName(t *testing.T) { 189 type args struct { 190 ctx context.Context 191 } 192 tests := []struct { 193 name string 194 args args 195 want string 196 want1 bool 197 }{ 198 {name: "Success", args: args{testCtx}, want: idlServiceName, want1: true}, 199 {name: "Failure", args: args{context.Background()}, want: "", want1: false}, 200 {name: "Panic recovered", args: args{panicCtx}, want: "", want1: false}, 201 } 202 for _, tt := range tests { 203 t.Run(tt.name, func(t *testing.T) { 204 got, got1 := GetIDLServiceName(tt.args.ctx) 205 if !reflect.DeepEqual(got, tt.want) { 206 t.Errorf("GetCallerHandlerMethod() got = %v, want %v", got, tt.want) 207 } 208 if got1 != tt.want1 { 209 t.Errorf("GetCallerHandlerMethod() got1 = %v, want %v", got1, tt.want1) 210 } 211 }) 212 } 213 } 214 215 func TestGetRPCInfo(t *testing.T) { 216 type args struct { 217 ctx context.Context 218 } 219 tests := []struct { 220 name string 221 args args 222 want rpcinfo.RPCInfo 223 want1 bool 224 }{ 225 {name: "Success", args: args{testCtx}, want: testRi, want1: true}, 226 {name: "Failure", args: args{context.Background()}, want: nil, want1: false}, 227 } 228 for _, tt := range tests { 229 t.Run(tt.name, func(t *testing.T) { 230 got, got1 := GetRPCInfo(tt.args.ctx) 231 if !reflect.DeepEqual(got, tt.want) { 232 t.Errorf("GetRPCInfo() got = %v, want %v", got, tt.want) 233 } 234 if got1 != tt.want1 { 235 t.Errorf("GetRPCInfo() got1 = %v, want %v", got1, tt.want1) 236 } 237 }) 238 } 239 } 240 241 func TestGetCtxTransportProtocol(t *testing.T) { 242 type args struct { 243 ctx context.Context 244 } 245 tests := []struct { 246 name string 247 args args 248 want string 249 want1 bool 250 }{ 251 {name: "Success", args: args{testCtx}, want: tp.String(), want1: true}, 252 {name: "Failure", args: args{context.Background()}, want: "", want1: false}, 253 {name: "Panic recovered", args: args{panicCtx}, want: "", want1: false}, 254 } 255 for _, tt := range tests { 256 t.Run(tt.name, func(t *testing.T) { 257 got, got1 := GetTransportProtocol(tt.args.ctx) 258 if got != tt.want { 259 t.Errorf("GetTransportProtocol() got = %v, want %v", got, tt.want) 260 } 261 if got1 != tt.want1 { 262 t.Errorf("GetTransportProtocol() got1 = %v, want %v", got1, tt.want1) 263 } 264 }) 265 } 266 } 267 268 func TestGetRealRequest(t *testing.T) { 269 req := &mocks.MockReq{} 270 arg := &mocks.MockTestArgs{Req: req} 271 type args struct { 272 req interface{} 273 } 274 tests := []struct { 275 name string 276 args args 277 want interface{} 278 }{ 279 {name: "success", args: args{arg}, want: req}, 280 {name: "nil input", args: args{nil}, want: nil}, 281 {name: "wrong interface", args: args{req}, want: nil}, 282 } 283 for _, tt := range tests { 284 t.Run(tt.name, func(t *testing.T) { 285 if got := GetRealReqFromKitexArgs(tt.args.req); !reflect.DeepEqual(got, tt.want) { 286 t.Errorf("GetRealReqFromKitexArgs() = %v, want %v", got, tt.want) 287 } 288 }) 289 } 290 } 291 292 func TestGetRealResponse(t *testing.T) { 293 success := "success" 294 result := &mocks.MockTestResult{Success: &success} 295 type args struct { 296 resp interface{} 297 } 298 tests := []struct { 299 name string 300 args args 301 want interface{} 302 }{ 303 {name: "success", args: args{result}, want: &success}, 304 {name: "nil input", args: args{nil}, want: nil}, 305 {name: "wrong interface", args: args{success}, want: nil}, 306 } 307 for _, tt := range tests { 308 t.Run(tt.name, func(t *testing.T) { 309 if got := GetRealRespFromKitexResult(tt.args.resp); !reflect.DeepEqual(got, tt.want) { 310 t.Errorf("GetRealRespFromKitexResult() = %v, want %v", got, tt.want) 311 } 312 }) 313 } 314 } 315 316 func buildRPCInfo() rpcinfo.RPCInfo { 317 from := rpcinfo.NewEndpointInfo(caller, fromMethod, fromAddr, nil) 318 to := rpcinfo.NewEndpointInfo(callee, method, nil, nil) 319 ink := rpcinfo.NewInvocation(idlServiceName, method) 320 config := rpcinfo.NewRPCConfig() 321 config.(rpcinfo.MutableRPCConfig).SetTransportProtocol(tp) 322 323 stats := rpcinfo.NewRPCStats() 324 ri := rpcinfo.NewRPCInfo(from, to, ink, config, stats) 325 return ri 326 } 327 328 type panicRPCInfo struct{} 329 330 func (m *panicRPCInfo) From() rpcinfo.EndpointInfo { panic("Panic when invoke From") } 331 func (m *panicRPCInfo) To() rpcinfo.EndpointInfo { panic("Panic when invoke To") } 332 func (m *panicRPCInfo) Invocation() rpcinfo.Invocation { panic("Panic when invoke Invocation") } 333 func (m *panicRPCInfo) Config() rpcinfo.RPCConfig { panic("Panic when invoke Config") } 334 func (m *panicRPCInfo) Stats() rpcinfo.RPCStats { panic("Panic when invoke Stats") }