trpc.group/trpc-go/trpc-go@v1.0.3/codec/message_test.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 package codec_test 15 16 import ( 17 "context" 18 "errors" 19 "net" 20 "reflect" 21 "testing" 22 "time" 23 24 "github.com/stretchr/testify/assert" 25 "github.com/stretchr/testify/require" 26 trpcpb "trpc.group/trpc/trpc-protocol/pb/go/trpc" 27 28 trpc "trpc.group/trpc-go/trpc-go" 29 "trpc.group/trpc-go/trpc-go/codec" 30 "trpc.group/trpc-go/trpc-go/errs" 31 "trpc.group/trpc-go/trpc-go/log" 32 ) 33 34 // go test -v -coverprofile=cover.out 35 // go tool cover -func=cover.out 36 37 func TestPutBackMessage(t *testing.T) { 38 ctx := context.Background() 39 _, msg := codec.WithCloneMessage(ctx) 40 type foo struct { 41 I int 42 } 43 44 msg.WithRemoteAddr(&net.TCPAddr{IP: net.ParseIP("127.0.0.2")}) 45 msg.WithLocalAddr(&net.TCPAddr{IP: net.ParseIP("127.0.0.1")}) 46 msg.WithNamespace("2") 47 msg.WithEnvName("3") 48 msg.WithSetName("4") 49 msg.WithEnvTransfer("5") 50 msg.WithRequestTimeout(time.Second) 51 msg.WithSerializationType(1) 52 msg.WithCompressType(2) 53 msg.WithServerRPCName("6") 54 msg.WithClientRPCName("7") 55 msg.WithCallerServiceName("8") 56 msg.WithCalleeServiceName("9") 57 msg.WithCallerApp("10") 58 msg.WithCallerServer("11") 59 msg.WithCallerService("12") 60 msg.WithCallerMethod("13") 61 msg.WithCalleeApp("14") 62 msg.WithCalleeServer("15") 63 msg.WithCalleeService("16") 64 msg.WithCalleeMethod("17") 65 msg.WithCalleeContainerName("18") 66 msg.WithServerMetaData(codec.MetaData{"a": []byte("1")}) 67 msg.WithFrameHead(foo{I: 1}) 68 msg.WithServerReqHead(foo{I: 2}) 69 msg.WithServerRspHead(foo{I: 3}) 70 msg.WithDyeing(true) 71 msg.WithDyeingKey("19") 72 msg.WithServerRspErr(errors.New("err1")) 73 msg.WithClientMetaData(codec.MetaData{"b": []byte("2")}) 74 msg.WithClientReqHead(foo{I: 4}) 75 msg.WithClientRspErr(errors.New("err2")) 76 msg.WithClientRspHead(foo{I: 5}) 77 msg.WithLogger(foo{I: 6}) 78 msg.WithRequestID(3) 79 msg.WithStreamID(4) 80 msg.WithStreamFrame(foo{I: 6}) 81 msg.WithCalleeSetName("20") 82 msg.WithCommonMeta(codec.CommonMeta{21: []byte("hello")}) 83 msg.WithCallType(codec.SendOnly) 84 85 codec.PutBackMessage(msg) 86 87 ctx2 := context.Background() 88 _, msg2 := codec.WithNewMessage(ctx2) 89 90 assert.Nil(t, msg2.FrameHead()) 91 assert.Equal(t, time.Duration(0), msg2.RequestTimeout()) 92 assert.Equal(t, 0, msg2.SerializationType()) 93 assert.Equal(t, 0, msg2.CompressType()) 94 assert.Equal(t, false, msg2.Dyeing()) 95 assert.Equal(t, "", msg2.DyeingKey()) 96 assert.Equal(t, "", msg2.ServerRPCName()) 97 assert.Equal(t, "", msg2.ClientRPCName()) 98 assert.Nil(t, msg2.ServerMetaData()) 99 assert.Nil(t, msg2.ClientMetaData()) 100 assert.Equal(t, "", msg2.CallerServiceName()) 101 assert.Equal(t, "", msg2.CalleeServiceName()) 102 assert.Equal(t, "", msg2.CalleeContainerName()) 103 assert.Nil(t, msg2.ServerRspErr()) 104 assert.Nil(t, msg2.ClientRspErr()) 105 assert.Nil(t, msg2.ServerReqHead()) 106 assert.Nil(t, msg2.ServerRspHead()) 107 assert.Nil(t, msg2.ClientReqHead()) 108 assert.Nil(t, msg2.ClientRspHead()) 109 assert.Nil(t, msg2.LocalAddr()) 110 assert.Nil(t, msg2.RemoteAddr()) 111 assert.Nil(t, msg2.Logger()) 112 assert.Equal(t, "", msg2.CallerApp()) 113 assert.Equal(t, "", msg2.CallerServer()) 114 assert.Equal(t, "", msg2.CallerService()) 115 assert.Equal(t, "", msg2.CallerMethod()) 116 assert.Equal(t, "", msg2.CalleeApp()) 117 assert.Equal(t, "", msg2.CalleeServer()) 118 assert.Equal(t, "", msg2.CalleeService()) 119 assert.Equal(t, "", msg2.CalleeMethod()) 120 assert.Equal(t, "", msg2.Namespace()) 121 assert.Equal(t, "", msg2.SetName()) 122 assert.Equal(t, "", msg2.EnvName()) 123 assert.Equal(t, "", msg2.EnvTransfer()) 124 assert.Equal(t, uint32(0), msg2.RequestID()) 125 assert.Nil(t, msg2.StreamFrame()) 126 assert.Equal(t, uint32(0), msg2.StreamID()) 127 assert.Equal(t, "", msg2.CalleeSetName()) 128 assert.Nil(t, msg2.CommonMeta()) 129 assert.Equal(t, codec.SendAndRecv, msg2.CallType()) 130 131 } 132 133 func TestRegisterMessage(t *testing.T) { 134 ctx := context.Background() 135 m0 := codec.Message(ctx) 136 assert.NotNil(t, m0) 137 assert.Equal(t, ctx, m0.Context()) 138 ctx, m0 = codec.WithCloneMessage(ctx) 139 assert.NotNil(t, m0) 140 assert.Equal(t, ctx, m0.Context()) 141 142 meta := codec.MetaData{} 143 reqhead := &trpcpb.RequestProtocol{} 144 rsphead := &trpcpb.ResponseProtocol{} 145 146 meta["key"] = []byte("value") 147 clone := meta.Clone() 148 assert.Equal(t, []byte("value"), clone["key"]) 149 150 ctx, msg := codec.WithNewMessage(ctx) 151 assert.NotNil(t, msg) 152 assert.NotNil(t, ctx) 153 assert.Equal(t, ctx, msg.Context()) 154 155 msg.WithRequestTimeout(time.Second) 156 assert.Equal(t, time.Second, msg.RequestTimeout()) 157 msg.WithSerializationType(codec.SerializationTypePB) 158 assert.Equal(t, codec.SerializationTypePB, msg.SerializationType()) 159 msg.WithServerRPCName("/package.service/method") 160 msg.WithServerRPCName("/package.service/method") // dup set 161 assert.Equal(t, "/package.service/method", msg.ServerRPCName()) 162 msg.WithServerMetaData(meta) 163 assert.Equal(t, meta, msg.ServerMetaData()) 164 msg.WithServerMetaData(nil) 165 assert.NotNil(t, msg.ServerMetaData()) 166 msg.WithServerReqHead(reqhead) 167 assert.Equal(t, reqhead, msg.ServerReqHead().(*trpcpb.RequestProtocol)) 168 msg.WithServerRspHead(rsphead) 169 assert.Equal(t, rsphead, msg.ServerRspHead().(*trpcpb.ResponseProtocol)) 170 msg.WithDyeing(true) 171 assert.Equal(t, true, msg.Dyeing()) 172 msg.WithDyeingKey("hellotrpc") 173 assert.Equal(t, "hellotrpc", msg.DyeingKey()) 174 175 var addr net.Addr 176 msg.WithRemoteAddr(addr) 177 assert.Equal(t, addr, msg.RemoteAddr()) 178 179 msg.WithLocalAddr(addr) 180 assert.Equal(t, addr, msg.LocalAddr()) 181 182 h := trpc.FrameHead{} 183 msg.WithFrameHead(h) 184 assert.Equal(t, h, msg.FrameHead()) 185 186 msg.WithCompressType(1) 187 assert.Equal(t, 1, msg.CompressType()) 188 189 msg.WithCallerApp("callerApp") 190 assert.Equal(t, "callerApp", msg.CallerApp()) 191 msg.WithCallerServer("callerServer") 192 assert.Equal(t, "callerServer", msg.CallerServer()) 193 msg.WithCallerService("callerService") 194 assert.Equal(t, "callerService", msg.CallerService()) 195 msg.WithCallerMethod("callerMethod") 196 assert.Equal(t, "callerMethod", msg.CallerMethod()) 197 msg.WithCalleeApp("calleeApp") 198 assert.Equal(t, "calleeApp", msg.CalleeApp()) 199 msg.WithCalleeServer("calleeServer") 200 assert.Equal(t, "calleeServer", msg.CalleeServer()) 201 msg.WithCalleeService("calleeService") 202 assert.Equal(t, "calleeService", msg.CalleeService()) 203 msg.WithCalleeMethod("calleeMethod") 204 assert.Equal(t, "calleeMethod", msg.CalleeMethod()) 205 msg.WithSetName("setName") 206 assert.Equal(t, "setName", msg.SetName()) 207 msg.WithCalleeSetName("calleeSetName") 208 assert.Equal(t, "calleeSetName", msg.CalleeSetName()) 209 msg.WithEnvName("test") 210 assert.Equal(t, "test", msg.EnvName()) 211 msg.WithNamespace("Production") 212 assert.Equal(t, "Production", msg.Namespace()) 213 msg.WithEnvTransfer("test-test") 214 assert.Equal(t, "test-test", msg.EnvTransfer()) 215 msg.WithCalleeContainerName("container") 216 assert.Equal(t, "container", msg.CalleeContainerName()) 217 218 msg.WithLogger(log.DefaultLogger) 219 assert.NotNil(t, msg.Logger()) 220 221 msg.WithCallType(codec.SendOnly) 222 assert.Equal(t, msg.CallType(), codec.SendOnly) 223 } 224 225 func TestMoreRegisterMessage(t *testing.T) { 226 ctx := context.Background() 227 meta := codec.MetaData{} 228 commonMeta := codec.CommonMeta{32: []byte("aaa")} 229 ctx, msg := codec.WithNewMessage(ctx) 230 reqhead := &trpcpb.RequestProtocol{} 231 rsphead := &trpcpb.ResponseProtocol{} 232 // client codec marshal 233 msg.WithClientRPCName("/package.service/method") 234 msg.WithClientRPCName("/package.service/method") // dup set 235 assert.Equal(t, "/package.service/method", msg.ClientRPCName()) 236 msg.WithClientMetaData(meta) 237 assert.Equal(t, meta, msg.ClientMetaData()) 238 msg.WithClientMetaData(nil) 239 assert.NotNil(t, msg.ClientMetaData()) 240 msg.WithCommonMeta(commonMeta) 241 assert.Equal(t, commonMeta, msg.CommonMeta()) 242 msg.WithCallerServiceName("package.service") 243 msg.WithCallerServiceName("package.service") // dup set 244 assert.Equal(t, "package.service", msg.CallerServiceName()) 245 msg.WithCalleeServiceName("package.service") 246 msg.WithCalleeServiceName("package.service") // dup set 247 assert.Equal(t, "package.service", msg.CalleeServiceName()) 248 msg.WithClientReqHead(reqhead) 249 assert.Equal(t, reqhead, msg.ClientReqHead().(*trpcpb.RequestProtocol)) 250 msg.WithClientRspHead(rsphead) 251 assert.Equal(t, rsphead, msg.ClientRspHead().(*trpcpb.ResponseProtocol)) 252 msg.WithCompressType(1) 253 assert.Equal(t, msg.CompressType(), 1) 254 255 // client codec unmarshal 256 msg.WithClientRspErr(errs.ErrServerNoResponse) 257 assert.Equal(t, errs.ErrServerNoResponse, msg.ClientRspErr()) 258 259 // trpc inner logic 260 assert.Nil(t, msg.ServerRspErr()) 261 msg.WithServerRspErr(errs.ErrServerNoResponse) 262 assert.Equal(t, errs.ErrServerNoResponse, msg.ServerRspErr()) 263 msg.WithServerRspErr(errors.New("no trpc errs")) 264 assert.EqualValues(t, int32(999), msg.ServerRspErr().Code) 265 266 m1 := codec.Message(ctx) 267 assert.Equal(t, msg, m1) 268 269 ctx, m2 := codec.WithCloneMessage(ctx) 270 assert.Equal(t, m2.ServerReqHead(), m1.ServerReqHead()) 271 assert.Equal(t, m2.ServerRspHead(), m1.ServerRspHead()) 272 assert.Equal(t, m2.CallerServiceName(), m1.CallerServiceName()) 273 assert.Equal(t, m2.RequestTimeout(), m1.RequestTimeout()) 274 assert.Equal(t, m2.ServerRPCName(), m1.ServerRPCName()) 275 assert.Equal(t, m2.SerializationType(), m1.SerializationType()) 276 assert.Equal(t, true, reflect.DeepEqual(m2.ServerMetaData(), m1.ServerMetaData())) 277 assert.Equal(t, m2.Dyeing(), m1.Dyeing()) 278 assert.Equal(t, m2.DyeingKey(), m1.DyeingKey()) 279 assert.Equal(t, m2.CommonMeta(), m1.CommonMeta()) 280 assert.NotEqual(t, m2.CompressType(), m1.CompressType()) 281 282 codec.PutBackMessage(msg) 283 ctx, m3 := codec.WithNewMessage(ctx) 284 assert.Equal(t, m3, msg) 285 assert.Equal(t, m3.CalleeApp(), "") 286 _, m4 := codec.WithNewMessage(ctx) 287 assert.NotEqual(t, m4, m1) 288 289 var fakemsg codec.Msg = nil 290 codec.PutBackMessage(fakemsg) 291 } 292 293 // TestWithCallerServiceName WithCallerServiceName 单测 294 func TestWithCallerServiceName(t *testing.T) { 295 ctx := trpc.BackgroundContext() 296 msg := codec.Message(ctx) 297 298 msg.WithCallerServiceName("trpc") 299 assert.Equal(t, "trpc", msg.CallerApp()) 300 assert.Equal(t, "", msg.CallerServer()) 301 assert.Equal(t, "", msg.CallerService()) 302 303 msg.WithCallerServiceName("app.server") 304 assert.Equal(t, "app", msg.CallerApp()) 305 assert.Equal(t, "server", msg.CallerServer()) 306 assert.Equal(t, "", msg.CallerService()) 307 308 msg.WithCallerServiceName("app.server.service") 309 assert.Equal(t, "app", msg.CallerApp()) 310 assert.Equal(t, "server", msg.CallerServer()) 311 assert.Equal(t, "service", msg.CallerService()) 312 313 msg.WithCallerServiceName("trpc.app.server.service") 314 assert.Equal(t, "app", msg.CallerApp()) 315 assert.Equal(t, "server", msg.CallerServer()) 316 assert.Equal(t, "service", msg.CallerService()) 317 318 msg.WithCallerServiceName("trpc.app.server.service.new") 319 assert.Equal(t, "app", msg.CallerApp()) 320 assert.Equal(t, "server", msg.CallerServer()) 321 assert.Equal(t, "service.new", msg.CallerService()) 322 323 msg.WithCallerServiceName("*") 324 assert.Equal(t, "*", msg.CallerServiceName()) 325 assert.Equal(t, "app", msg.CallerApp()) 326 assert.Equal(t, "server", msg.CallerServer()) 327 328 msg.WithCalleeServiceName("trpc") 329 assert.Equal(t, "trpc", msg.CalleeApp()) 330 assert.Equal(t, "", msg.CalleeServer()) 331 assert.Equal(t, "", msg.CalleeService()) 332 333 msg.WithCalleeServiceName("app.server.service") 334 assert.Equal(t, "app", msg.CalleeApp()) 335 assert.Equal(t, "server", msg.CalleeServer()) 336 assert.Equal(t, "service", msg.CalleeService()) 337 338 msg.WithCalleeServiceName("trpc.app.server.service") 339 assert.Equal(t, "app", msg.CalleeApp()) 340 assert.Equal(t, "server", msg.CalleeServer()) 341 assert.Equal(t, "service", msg.CalleeService()) 342 343 msg.WithCalleeServiceName("trpc.app.server.service.new") 344 assert.Equal(t, "app", msg.CalleeApp()) 345 assert.Equal(t, "server", msg.CalleeServer()) 346 assert.Equal(t, "service.new", msg.CalleeService()) 347 348 msg.WithCalleeServiceName("*") 349 assert.Equal(t, "*", msg.CalleeServiceName()) 350 assert.Equal(t, "app", msg.CalleeApp()) 351 assert.Equal(t, "server", msg.CalleeServer()) 352 353 } 354 355 func TestMsg_CopyMsg_1_CIFunctionStatementsMustLessThan80Lines(t *testing.T) { 356 ctx := context.Background() 357 msg := codec.Message(ctx) 358 359 msg.WithRemoteAddr(&net.TCPAddr{IP: net.ParseIP("127.0.0.2")}) 360 msg.WithLocalAddr(&net.TCPAddr{IP: net.ParseIP("127.0.0.1")}) 361 msg.WithNamespace("2") 362 msg.WithEnvName("3") 363 msg.WithSetName("4") 364 msg.WithEnvTransfer("5") 365 msg.WithRequestTimeout(time.Second) 366 msg.WithSerializationType(1) 367 msg.WithCompressType(2) 368 msg.WithServerRPCName("6") 369 msg.WithClientRPCName("7") 370 msg.WithCallerServiceName("8") 371 msg.WithCalleeServiceName("9") 372 msg.WithCallerApp("10") 373 msg.WithCallerServer("11") 374 msg.WithCallerService("12") 375 msg.WithCallerMethod("13") 376 377 _, newMsg := codec.WithNewMessage(ctx) 378 codec.CopyMsg(newMsg, msg) 379 380 require.False(t, reflect.DeepEqual(msg.Context(), newMsg.Context())) 381 require.True(t, reflect.DeepEqual(msg.RemoteAddr(), newMsg.RemoteAddr())) 382 require.True(t, reflect.DeepEqual(msg.LocalAddr(), newMsg.LocalAddr())) 383 require.True(t, reflect.DeepEqual(msg.Namespace(), newMsg.Namespace())) 384 require.True(t, reflect.DeepEqual(msg.EnvName(), newMsg.EnvName())) 385 require.True(t, reflect.DeepEqual(msg.SetName(), newMsg.SetName())) 386 require.True(t, reflect.DeepEqual(msg.EnvTransfer(), newMsg.EnvTransfer())) 387 require.True(t, reflect.DeepEqual(msg.RequestTimeout(), newMsg.RequestTimeout())) 388 require.True(t, reflect.DeepEqual(msg.SerializationType(), newMsg.SerializationType())) 389 require.True(t, reflect.DeepEqual(msg.CompressType(), newMsg.CompressType())) 390 require.True(t, reflect.DeepEqual(msg.ServerRPCName(), newMsg.ServerRPCName())) 391 require.True(t, reflect.DeepEqual(msg.ClientRPCName(), newMsg.ClientRPCName())) 392 require.True(t, reflect.DeepEqual(msg.CallerServiceName(), newMsg.CallerServiceName())) 393 require.True(t, reflect.DeepEqual(msg.CalleeServiceName(), newMsg.CalleeServiceName())) 394 require.True(t, reflect.DeepEqual(msg.CallerApp(), newMsg.CallerApp())) 395 require.True(t, reflect.DeepEqual(msg.CallerServer(), newMsg.CallerServer())) 396 require.True(t, reflect.DeepEqual(msg.CallerService(), newMsg.CallerService())) 397 require.True(t, reflect.DeepEqual(msg.CallerMethod(), newMsg.CallerMethod())) 398 } 399 400 func TestMsg_CopyMsg_2_CIFunctionStatementsMustLessThan80Lines(t *testing.T) { 401 ctx := context.Background() 402 msg := codec.Message(ctx) 403 type foo struct { 404 I int 405 } 406 407 msg.WithCalleeApp("14") 408 msg.WithCalleeServer("15") 409 msg.WithCalleeService("16") 410 msg.WithCalleeMethod("17") 411 msg.WithCalleeContainerName("18") 412 msg.WithServerMetaData(codec.MetaData{"a": []byte("1")}) 413 msg.WithFrameHead(foo{I: 1}) 414 msg.WithServerReqHead(foo{I: 2}) 415 msg.WithServerRspHead(foo{I: 3}) 416 msg.WithDyeing(true) 417 msg.WithDyeingKey("19") 418 msg.WithServerRspErr(errors.New("err1")) 419 msg.WithClientMetaData(codec.MetaData{"b": []byte("2")}) 420 msg.WithClientReqHead(foo{I: 4}) 421 msg.WithClientRspErr(errors.New("err2")) 422 msg.WithClientRspHead(foo{I: 5}) 423 msg.WithLogger(foo{I: 6}) 424 msg.WithRequestID(3) 425 msg.WithStreamID(4) 426 msg.WithStreamFrame(foo{I: 6}) 427 msg.WithCalleeSetName("20") 428 msg.WithCommonMeta(codec.CommonMeta{21: []byte("hello")}) 429 msg.WithCallType(codec.SendOnly) 430 431 _, newMsg := codec.WithNewMessage(ctx) 432 codec.CopyMsg(newMsg, msg) 433 434 require.False(t, reflect.DeepEqual(msg.Context(), newMsg.Context())) 435 require.True(t, reflect.DeepEqual(msg.CalleeApp(), newMsg.CalleeApp())) 436 require.True(t, reflect.DeepEqual(msg.CalleeServer(), newMsg.CalleeServer())) 437 require.True(t, reflect.DeepEqual(msg.CalleeService(), newMsg.CalleeService())) 438 require.True(t, reflect.DeepEqual(msg.CalleeMethod(), newMsg.CalleeMethod())) 439 require.True(t, reflect.DeepEqual(msg.CalleeContainerName(), newMsg.CalleeContainerName())) 440 require.True(t, reflect.DeepEqual(msg.ServerMetaData(), newMsg.ServerMetaData())) 441 require.True(t, reflect.DeepEqual(msg.FrameHead(), newMsg.FrameHead())) 442 require.True(t, reflect.DeepEqual(msg.ServerReqHead(), newMsg.ServerReqHead())) 443 require.True(t, reflect.DeepEqual(msg.ServerRspHead(), newMsg.ServerRspHead())) 444 require.True(t, reflect.DeepEqual(msg.Dyeing(), newMsg.Dyeing())) 445 require.True(t, reflect.DeepEqual(msg.DyeingKey(), newMsg.DyeingKey())) 446 require.True(t, reflect.DeepEqual(msg.ServerRspErr(), newMsg.ServerRspErr())) 447 require.True(t, reflect.DeepEqual(msg.ClientMetaData(), newMsg.ClientMetaData())) 448 require.True(t, reflect.DeepEqual(msg.ClientReqHead(), newMsg.ClientReqHead())) 449 require.True(t, reflect.DeepEqual(msg.ClientRspErr(), newMsg.ClientRspErr())) 450 require.True(t, reflect.DeepEqual(msg.ClientRspHead(), newMsg.ClientRspHead())) 451 require.True(t, reflect.DeepEqual(msg.Logger(), newMsg.Logger())) 452 require.True(t, reflect.DeepEqual(msg.RequestID(), newMsg.RequestID())) 453 require.True(t, reflect.DeepEqual(msg.StreamID(), newMsg.StreamID())) 454 require.True(t, reflect.DeepEqual(msg.StreamFrame(), newMsg.StreamFrame())) 455 require.True(t, reflect.DeepEqual(msg.CalleeSetName(), newMsg.CalleeSetName())) 456 require.True(t, reflect.DeepEqual(msg.CommonMeta(), newMsg.CommonMeta())) 457 require.True(t, reflect.DeepEqual(msg.CallType(), newMsg.CallType())) 458 459 // make sure map is deeply copied. 460 newMsg.ServerMetaData()["aa"] = []byte("11") 461 require.False(t, reflect.DeepEqual(msg.ServerMetaData(), newMsg.ServerMetaData())) 462 newMsg.ClientMetaData()["bb"] = []byte("22") 463 require.False(t, reflect.DeepEqual(msg.ClientMetaData(), newMsg.ClientMetaData())) 464 465 } 466 467 func TestEnsureMessage(t *testing.T) { 468 ctx := context.Background() 469 newCtx, msg := codec.EnsureMessage(ctx) 470 require.NotEqual(t, ctx, newCtx) 471 require.Equal(t, msg, codec.Message(newCtx)) 472 473 ctx = trpc.BackgroundContext() 474 msg = codec.Message(ctx) 475 require.NotNil(t, msg) 476 newCtx, newMsg := codec.EnsureMessage(ctx) 477 require.Equal(t, ctx, newCtx) 478 require.Equal(t, msg, newMsg) 479 } 480 481 func TestSetMethodNameUsingRPCName(t *testing.T) { 482 msg := codec.Message(context.Background()) 483 testSetMethodNameUsingRPCName(t, msg, msg.WithServerRPCName) 484 testSetMethodNameUsingRPCName(t, msg, msg.WithClientRPCName) 485 } 486 487 func testSetMethodNameUsingRPCName(t *testing.T, msg codec.Msg, msgWithRPCName func(string)) { 488 var cases = []struct { 489 name string 490 originalMethod string 491 rpcName string 492 expectMethod string 493 }{ 494 {"normal trpc rpc name", "", "/trpc.app.server.service/method", "method"}, 495 {"normal http url path", "", "/v1/subject/info/get", "/v1/subject/info/get"}, 496 {"invalid trpc rpc name (method name is empty)", "", "trpc.app.server.service", "trpc.app.server.service"}, 497 {"invalid trpc rpc name (method name is not mepty)", "/v1/subject/info/get", "trpc.app.server.service", "/v1/subject/info/get"}, 498 {"valid trpc rpc name will override existing method name", "/v1/subject/info/get", "/trpc.app.server.service/method", "method"}, 499 {"invalid trpc rpc will not override existing method name", "/v1/subject/info/get", "/trpc.app.server.service", "/v1/subject/info/get"}, 500 } 501 502 for _, tt := range cases { 503 t.Run(tt.name, func(t *testing.T) { 504 resetMsgRPCNameAndMethodName(msg) 505 msg.WithCalleeMethod(tt.originalMethod) 506 msgWithRPCName(tt.rpcName) 507 method := msg.CalleeMethod() 508 if method != tt.expectMethod { 509 t.Errorf("given original method %s and rpc name %s, expect new method name %s, got %s", 510 tt.originalMethod, tt.rpcName, tt.expectMethod, method) 511 } 512 }) 513 } 514 } 515 516 func resetMsgRPCNameAndMethodName(msg codec.Msg) { 517 msg.WithCalleeMethod("") 518 msg.WithClientRPCName("") 519 msg.WithServerRPCName("") 520 }