trpc.group/trpc-go/trpc-go@v1.0.3/admin/admin_test.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 package admin 15 16 import ( 17 "context" 18 "encoding/json" 19 "errors" 20 "fmt" 21 "io" 22 "net" 23 "net/http" 24 "os" 25 "reflect" 26 "strings" 27 "sync" 28 "testing" 29 "time" 30 "unsafe" 31 32 "github.com/stretchr/testify/assert" 33 "github.com/stretchr/testify/require" 34 35 "trpc.group/trpc-go/trpc-go/config" 36 "trpc.group/trpc-go/trpc-go/healthcheck" 37 "trpc.group/trpc-go/trpc-go/log" 38 "trpc.group/trpc-go/trpc-go/rpcz" 39 "trpc.group/trpc-go/trpc-go/transport" 40 ) 41 42 const ( 43 testVersion = "v0.2.0-alpha" 44 testAddress = "localhost:0" 45 testConfigPath = "../testdata/trpc_go.yaml" 46 ) 47 48 func newDefaultAdminServer() *Server { 49 s := NewServer( 50 WithVersion(testVersion), 51 WithAddr(testAddress), 52 WithTLS(false), 53 WithReadTimeout(defaultReadTimeout), 54 WithWriteTimeout(defaultWriteTimeout), 55 WithConfigPath(testConfigPath), 56 ) 57 58 s.HandleFunc("/usercmd", userCmd) 59 s.HandleFunc("/errout", errOutput) 60 s.HandleFunc("/panicHandle", panicHandle) 61 62 return s 63 } 64 65 func mustStartAdminServer(t *testing.T, s *Server) { 66 t.Helper() 67 68 go func() { 69 if err := s.Serve(); err != nil { 70 t.Log(err) 71 } 72 }() 73 time.Sleep(200 * time.Millisecond) 74 } 75 76 func TestRPCZFailed(t *testing.T) { 77 s := newDefaultAdminServer() 78 mustStartAdminServer(t, s) 79 t.Cleanup(func() { 80 if err := s.Close(nil); err != nil { 81 t.Log(err) 82 } 83 }) 84 tests := []struct { 85 name string 86 url string 87 errorCode int 88 message string 89 content interface{} 90 }{ 91 { 92 name: "handleSpans failed because query parameter isn't a number", 93 url: fmt.Sprintf("http://%s", s.server.Addr) + patternRPCZSpansList + "?num=xxx", 94 errorCode: errCodeServer, 95 message: "must be a integer", 96 content: "", 97 }, 98 { 99 name: "handleSpans failed because query parameter isn't a positive integer", 100 url: fmt.Sprintf("http://%s", s.server.Addr) + patternRPCZSpansList + "?num=-1", 101 errorCode: errCodeServer, 102 message: "must be a non-negative integer", 103 content: nil, 104 }, 105 { 106 name: "handleSpan failed because can't find span_id", 107 url: fmt.Sprintf("http://%s", s.server.Addr) + patternRPCZSpanGet + "1", 108 errorCode: errCodeServer, 109 message: "cannot find span-id", 110 content: nil, 111 }, 112 { 113 name: "handleSpan failed because query parameter span_id is empty", 114 url: fmt.Sprintf("http://%s", s.server.Addr) + patternRPCZSpanGet + "", 115 errorCode: errCodeServer, 116 message: "undefined command", 117 content: nil, 118 }, 119 { 120 name: "handleSpan failed because query parameter span_id is negative", 121 url: fmt.Sprintf("http://%s", s.server.Addr) + patternRPCZSpanGet + "-1", 122 errorCode: errCodeServer, 123 message: "can not be negative", 124 content: nil, 125 }, 126 } 127 for _, tt := range tests { 128 t.Run(tt.name, func(t *testing.T) { 129 r, err := httpRequest(http.MethodGet, tt.url, "") 130 require.Nil(t, err) 131 require.Contains(t, string(r), tt.message) 132 }) 133 } 134 t.Run("url query doesn't match rpcz", func(t *testing.T) { 135 r, err := httpRequest(http.MethodGet, fmt.Sprintf("http://%s", s.server.Addr)+"/cmd/rpcz", "") 136 require.Nil(t, err) 137 require.Contains(t, string(r), "404 page not found") 138 }) 139 } 140 141 type sliceSpanExporter struct { 142 spans []rpcz.ReadOnlySpan 143 } 144 145 func (e *sliceSpanExporter) Export(span *rpcz.ReadOnlySpan) { 146 e.spans = append(e.spans, *span) 147 } 148 149 func TestRPC_Exporter(t *testing.T) { 150 s := newDefaultAdminServer() 151 mustStartAdminServer(t, s) 152 t.Cleanup(func() { 153 if err := s.Close(nil); err != nil { 154 t.Log(err) 155 } 156 }) 157 oldGlobalRPCZ := rpcz.GlobalRPCZ 158 defer func() { 159 rpcz.GlobalRPCZ = oldGlobalRPCZ 160 }() 161 // Given a GlobalRPCZ configured with exporter 162 exporter := &sliceSpanExporter{} 163 rpcz.GlobalRPCZ = rpcz.NewRPCZ(&rpcz.Config{Fraction: 1.0, Capacity: 10, Exporter: exporter}) 164 165 // When End a "server" span with spanID. 166 span := rpcz.SpanFromContext(context.Background()) 167 cs, end := span.NewChild("server") 168 spanID := cs.ID() 169 end.End() 170 171 // Then the exporter contain the span exported by the GlobalRPCZ 172 require.Len(t, exporter.spans, 1) 173 require.Equal(t, spanID, exporter.spans[0].ID) 174 175 // And the GlobalRPCZ still stores a copy of the exported span 176 rRaw, err := httpRequest(http.MethodGet, fmt.Sprintf("http://%s", s.server.Addr)+patternRPCZSpansList+"?num", "") 177 require.Nil(t, err) 178 require.Contains(t, string(rRaw), fmt.Sprint(spanID)) 179 } 180 181 func TestRPCZOk(t *testing.T) { 182 s := newDefaultAdminServer() 183 mustStartAdminServer(t, s) 184 t.Cleanup(func() { 185 if err := s.Close(nil); err != nil { 186 t.Log(err) 187 } 188 }) 189 oldGlobalRPCZ := rpcz.GlobalRPCZ 190 defer func() { 191 rpcz.GlobalRPCZ = oldGlobalRPCZ 192 }() 193 rpcz.GlobalRPCZ = rpcz.NewRPCZ(&rpcz.Config{Fraction: 1.0, Capacity: 10}) 194 span := rpcz.SpanFromContext(context.Background()) 195 196 cs, end := span.NewChild("server") 197 spanID := cs.ID() 198 end.End() 199 200 tests := []struct { 201 name string 202 url string 203 errorCode int 204 message string 205 content interface{} 206 }{ 207 { 208 name: "handleSpans ok query parameter num is empty", 209 url: fmt.Sprintf("http://%s", s.server.Addr) + patternRPCZSpansList + "?num", 210 content: fmt.Sprintf("1:\n span: (server, %d)\n", spanID), 211 }, 212 { 213 name: "handleSpans ok without any query parameter", 214 url: fmt.Sprintf("http://%s", s.server.Addr) + patternRPCZSpansList, 215 content: fmt.Sprintf("1:\n span: (server, %d)\n", spanID), 216 }, 217 { 218 name: "handleSpans ok", 219 url: fmt.Sprintf("http://%s", s.server.Addr) + patternRPCZSpansList + "?num=1", 220 content: fmt.Sprintf("1:\n span: (server, %d)\n", spanID), 221 }, 222 { 223 name: "handleSpan ok", 224 url: fmt.Sprintf("http://%s", s.server.Addr) + patternRPCZSpanGet + fmt.Sprint(spanID), 225 content: fmt.Sprintf("span: (server, %d)\n", spanID), 226 }, 227 } 228 for _, tt := range tests { 229 t.Run(tt.name, func(t *testing.T) { 230 rRaw, err := httpRequest(http.MethodGet, tt.url, "") 231 r := string(rRaw) 232 require.Nil(t, err) 233 require.Contains(t, r, tt.message) 234 require.Contains(t, r, tt.content) 235 236 }) 237 } 238 } 239 240 func TestCmdVersion(t *testing.T) { 241 s := newDefaultAdminServer() 242 mustStartAdminServer(t, s) 243 t.Cleanup(func() { 244 if err := s.Close(nil); err != nil { 245 t.Log(err) 246 } 247 }) 248 versionURL := fmt.Sprintf("http://%s", s.server.Addr) + "/version" 249 respData, err := httpRequest(http.MethodGet, versionURL, "") 250 if err != nil { 251 require.Nil(t, err, "httpGetBody failed") 252 return 253 } 254 255 res := struct { 256 Errcode int `json:"errorcode"` 257 Message string `json:"message"` 258 Version string `json:"version"` 259 }{} 260 err = json.Unmarshal(respData, &res) 261 require.Nil(t, err, "testAdminServerVersion unmarshal failed") 262 require.Equal(t, 0, res.Errcode) 263 require.Equal(t, testVersion, res.Version) 264 } 265 266 func TestCmdsLogLevel(t *testing.T) { 267 s := newDefaultAdminServer() 268 mustStartAdminServer(t, s) 269 t.Cleanup(func() { 270 if err := s.Close(nil); err != nil { 271 t.Log(err) 272 } 273 }) 274 275 dlogger := log.GetDefaultLogger() 276 277 // Preset test conditions 278 log.Register("default", log.NewZapLog([]log.OutputConfig{ 279 {Writer: log.OutputConsole, Level: "debug"}, 280 {Writer: log.OutputFile, WriteConfig: log.WriteConfig{Filename: "test"}, Level: "info"}, 281 })) 282 283 t.Cleanup(func() { 284 log.Register("default", dlogger) 285 }) 286 287 res := struct { 288 Errcode int `json:"errorcode"` 289 Message string `json:"message"` 290 Level string `json:"level"` 291 PreLevel string `json:"prelevel"` 292 }{} 293 294 t.Run("right case", func(t *testing.T) { 295 logURL := fmt.Sprintf("http://%s", s.server.Addr) + "/cmds/loglevel?logger=default&output=1" 296 // TestGet 297 respData, err := httpRequest(http.MethodGet, logURL, "") 298 require.Nil(t, err, "httpGetBody failed") 299 300 err = json.Unmarshal(respData, &res) 301 require.Nil(t, err, "testAdminServerLogLevel unmarshal failed") 302 require.Equal(t, 0, res.Errcode) 303 require.Equal(t, "info", res.Level) 304 305 // TestUpdate 306 body, err := httpRequest(http.MethodPut, logURL, "value=debug") 307 require.Nil(t, err, "httpRequest failed:", err) 308 err = json.Unmarshal(body, &res) 309 require.Nil(t, err, "Unmarshal failed:", err) 310 require.Equal(t, 0, res.Errcode) 311 require.Equal(t, "info", res.PreLevel) 312 require.Equal(t, "debug", res.Level) 313 }) 314 t.Run("request parameter is empty", func(t *testing.T) { 315 logURL := fmt.Sprintf("http://%s", s.server.Addr) + "/cmds/loglevel" 316 respData, err := httpRequest(http.MethodGet, logURL, "") 317 require.Nil(t, err, "httpGetBody failed") 318 319 err = json.Unmarshal(respData, &res) 320 require.Nil(t, err, "unmarshal failed") 321 require.Equal(t, 0, res.Errcode) 322 require.Equal(t, "debug", res.Level) 323 }) 324 t.Run("failed to parse request parameters", func(t *testing.T) { 325 logURL := fmt.Sprintf("http://%s", s.server.Addr) + "/cmds/loglevel?logger%" 326 respData, err := httpRequest(http.MethodGet, logURL, "") 327 require.Nil(t, err, "httpGetBody failed:", err) 328 329 err = json.Unmarshal(respData, &res) 330 require.Nil(t, err, "Unmarshal failed", err) 331 require.Equal(t, errCodeServer, res.Errcode) 332 }) 333 t.Run("logger is invalid", func(t *testing.T) { 334 logURL := fmt.Sprintf("http://%s", s.server.Addr) + "/cmds/loglevel?logger=invalid" 335 respData, err := httpRequest(http.MethodGet, logURL, "") 336 require.Nil(t, err, "httpGetBody failed:", err) 337 338 err = json.Unmarshal(respData, &res) 339 require.Nil(t, err, "Unmarshal failed", err) 340 require.Equal(t, errCodeServer, res.Errcode) 341 require.Equal(t, "logger invalid not found", res.Message) 342 }) 343 } 344 345 func TestCmdsConfig(t *testing.T) { 346 s := newDefaultAdminServer() 347 mustStartAdminServer(t, s) 348 t.Cleanup(func() { 349 if err := s.Close(nil); err != nil { 350 t.Log(err) 351 } 352 }) 353 configURL := fmt.Sprintf("http://%s//cmds/config", s.server.Addr) 354 res := struct { 355 Errcode int `json:"errorcode"` 356 Message string `json:"message"` 357 Content interface{} `json:"content"` 358 }{} 359 t.Run("failed to read configuration file", func(t *testing.T) { 360 // Replace invalid config path 361 s.config.configPath = "./invalid/invalid.yaml" 362 respData, err := httpRequest(http.MethodGet, configURL, "") 363 // Adjust back to the correct path 364 s.config.configPath = testConfigPath 365 require.Nil(t, err, "httpGetBody failed") 366 367 err = json.Unmarshal(respData, &res) 368 require.Nil(t, err, "unmarshal failed", err) 369 require.Equal(t, errCodeServer, res.Errcode) 370 }) 371 t.Run("failed to get unmarshaler", func(t *testing.T) { 372 // Replace invalid unmarshaler 373 config.RegisterUnmarshaler("yaml", nil) 374 respData, err := httpRequest(http.MethodGet, configURL, "") 375 // Adjust back to the correct unmarshaler 376 config.RegisterUnmarshaler("yaml", &config.YamlUnmarshaler{}) 377 if err != nil { 378 require.Nil(t, err, "httpGetBody failed") 379 return 380 } 381 382 err = json.Unmarshal(respData, &res) 383 require.Nil(t, err, "unmarshal failed", err) 384 require.Equal(t, errCodeServer, res.Errcode) 385 require.Equal(t, "cannot find yaml unmarshaler", res.Message) 386 }) 387 t.Run("failed to unmarshal configuration file", func(t *testing.T) { 388 // Replace invalid config path 389 s.config.configPath = "../testdata/greeter.trpc.go" 390 respData, err := httpRequest(http.MethodGet, configURL, "") 391 // Adjust back to the correct path 392 s.config.configPath = testConfigPath 393 require.Nil(t, err, "httpGetBody failed") 394 395 err = json.Unmarshal(respData, &res) 396 require.Nil(t, err, "unmarshal failed", err) 397 require.Equal(t, errCodeServer, res.Errcode) 398 }) 399 t.Run("right case", func(t *testing.T) { 400 time.Sleep(1 * time.Second) 401 respData, err := httpRequest(http.MethodGet, configURL, "") 402 require.Nil(t, err, "httpGetBody failed") 403 404 err = json.Unmarshal(respData, &res) 405 require.Nil(t, err, "unmarshal failed", err) 406 require.Equal(t, 0, res.Errcode) 407 require.NotNil(t, res.Content, "config content is empty") 408 }) 409 } 410 411 func TestCmdsHealthCheck(t *testing.T) { 412 s := newDefaultAdminServer() 413 mustStartAdminServer(t, s) 414 t.Cleanup(func() { 415 if err := s.Close(nil); err != nil { 416 t.Log(err) 417 } 418 }) 419 420 rsp, err := http.Get(fmt.Sprintf("http://%s/is_healthy", s.server.Addr)) 421 require.Nil(t, err) 422 require.Equal(t, http.StatusOK, rsp.StatusCode) 423 424 rsp, err = http.Get(fmt.Sprintf("http://%s/is_healthy/", s.server.Addr)) 425 require.Nil(t, err) 426 require.Equal(t, http.StatusOK, rsp.StatusCode) 427 428 rsp, err = http.Get(fmt.Sprintf("http://%s/is_healthy/not_exist", s.server.Addr)) 429 require.Nil(t, err) 430 require.Equal(t, http.StatusNotFound, rsp.StatusCode) 431 432 unregister, update, err := s.RegisterHealthCheck("service") 433 require.Nil(t, err) 434 rsp, err = http.Get(fmt.Sprintf("http://%s/is_healthy", s.server.Addr)) 435 require.Nil(t, err) 436 require.Equal(t, http.StatusNotFound, rsp.StatusCode) 437 rsp, err = http.Get(fmt.Sprintf("http://%s/is_healthy/service", s.server.Addr)) 438 require.Nil(t, err) 439 require.Equal(t, http.StatusNotFound, rsp.StatusCode) 440 441 update(healthcheck.Serving) 442 rsp, err = http.Get(fmt.Sprintf("http://%s/is_healthy", s.server.Addr)) 443 require.Nil(t, err) 444 require.Equal(t, http.StatusOK, rsp.StatusCode) 445 rsp, err = http.Get(fmt.Sprintf("http://%s/is_healthy/service", s.server.Addr)) 446 require.Nil(t, err) 447 require.Equal(t, http.StatusOK, rsp.StatusCode) 448 449 update(healthcheck.NotServing) 450 rsp, err = http.Get(fmt.Sprintf("http://%s/is_healthy", s.server.Addr)) 451 require.Nil(t, err) 452 require.Equal(t, http.StatusServiceUnavailable, rsp.StatusCode) 453 rsp, err = http.Get(fmt.Sprintf("http://%s/is_healthy/service", s.server.Addr)) 454 require.Nil(t, err) 455 require.Equal(t, http.StatusServiceUnavailable, rsp.StatusCode) 456 457 unregister() 458 rsp, err = http.Get(fmt.Sprintf("http://%s/is_healthy", s.server.Addr)) 459 require.Nil(t, err) 460 require.Equal(t, http.StatusOK, rsp.StatusCode) 461 rsp, err = http.Get(fmt.Sprintf("http://%s/is_healthy/service", s.server.Addr)) 462 require.Nil(t, err) 463 require.Equal(t, http.StatusNotFound, rsp.StatusCode) 464 } 465 466 func TestCmds(t *testing.T) { 467 s := newDefaultAdminServer() 468 mustStartAdminServer(t, s) 469 t.Cleanup(func() { 470 if err := s.Close(nil); err != nil { 471 t.Log(err) 472 } 473 }) 474 475 usercmdURL := fmt.Sprintf("http://%s", s.server.Addr) + "/cmds" 476 respData, err := httpRequest(http.MethodGet, usercmdURL, "") 477 require.Nil(t, err, "cmds request failed") 478 479 res := struct { 480 Errcode int `json:"errorcode"` 481 Message string `json:"message"` 482 Cmds []string `json:"cmds"` 483 }{} 484 err = json.Unmarshal(respData, &res) 485 require.Nil(t, err, "Unmarshal failed") 486 } 487 488 func TestErrorOutput(t *testing.T) { 489 s := newDefaultAdminServer() 490 mustStartAdminServer(t, s) 491 t.Cleanup(func() { 492 if err := s.Close(nil); err != nil { 493 t.Log(err) 494 } 495 }) 496 usercmdURL := fmt.Sprintf("http://%s", s.server.Addr) + "/errout" 497 respData, err := httpRequest(http.MethodGet, usercmdURL, "") 498 require.Nil(t, err, "cmds request failed") 499 500 res := struct { 501 Errcode int `json:"errorcode"` 502 Message string `json:"message"` 503 }{} 504 err = json.Unmarshal(respData, &res) 505 require.Nil(t, err, "Unmarshal failed") 506 require.Equal(t, 100, res.Errcode) 507 require.Contains(t, res.Message, "error") 508 } 509 510 func TestPanicHandle(t *testing.T) { 511 s := newDefaultAdminServer() 512 mustStartAdminServer(t, s) 513 t.Cleanup(func() { 514 if err := s.Close(nil); err != nil { 515 t.Log(err) 516 } 517 }) 518 519 usercmdURL := fmt.Sprintf("http://%s", s.server.Addr) + "/panicHandle" 520 respData, err := httpRequest(http.MethodGet, usercmdURL, "") 521 require.Nil(t, err, "cmds request failed") 522 523 res := struct { 524 Errcode int `json:"errorcode"` 525 Message string `json:"message"` 526 }{} 527 err = json.Unmarshal(respData, &res) 528 require.Nil(t, err, "Unmarshal failed") 529 require.Equal(t, 500, res.Errcode) 530 require.Contains(t, res.Message, "panic") 531 } 532 533 func TestListen(t *testing.T) { 534 s := NewServer() 535 536 // listen fail on invalid address 537 err := os.Setenv(transport.EnvGraceRestart, "0") 538 assert.Nil(t, err) 539 ln, err := s.listen("tcp", "invalid address") 540 assert.NotNil(t, err) 541 assert.Nil(t, ln) 542 543 // listen success 544 ln, err = s.listen("tcp", "127.0.0.1:0") 545 assert.Nil(t, err) 546 assert.NotNil(t, ln) 547 defer func(ln net.Listener) { 548 assert.Nil(t, ln.Close()) 549 }(ln) 550 assert.IsType(t, &net.TCPListener{}, ln) 551 } 552 553 func TestClose(t *testing.T) { 554 s := newDefaultAdminServer() 555 mustStartAdminServer(t, s) 556 557 err := s.Close(nil) 558 require.Nil(t, err) 559 560 usercmdURL := fmt.Sprintf("http://%s/cmds", s.server.Addr) 561 _, err = httpRequest(http.MethodGet, usercmdURL, "") 562 var netErr *net.OpError 563 564 require.ErrorAs(t, err, &netErr) 565 } 566 567 func TestOptionsConfig(t *testing.T) { 568 s := newDefaultAdminServer() 569 WithTLS(true)(s.config) 570 err := s.Serve() 571 require.NotNil(t, err) 572 require.Contains(t, err.Error(), "not support") 573 } 574 575 func httpRequest(method string, url string, body string) ([]byte, error) { 576 request, err := http.NewRequest(method, url, strings.NewReader(body)) 577 request.Header.Set("content-type", "application/x-www-form-urlencoded") 578 if err != nil { 579 return nil, err 580 } 581 582 response, err := http.DefaultClient.Do(request) 583 if err != nil { 584 return nil, err 585 } 586 defer response.Body.Close() 587 return io.ReadAll(response.Body) 588 } 589 590 func userCmd(w http.ResponseWriter, r *http.Request) { 591 _, _ = w.Write([]byte("usercmd")) 592 } 593 594 func errOutput(w http.ResponseWriter, r *http.Request) { 595 ErrorOutput(w, "error output", 100) 596 } 597 598 func panicHandle(w http.ResponseWriter, r *http.Request) { 599 panic("panic error handle") 600 } 601 602 func TestUnregisterHandlers(t *testing.T) { 603 _ = newDefaultAdminServer() 604 mux, err := extractServeMuxData() 605 require.Nil(t, err) 606 require.Len(t, mux.m, 0) 607 require.Len(t, mux.es, 0) 608 require.False(t, mux.hosts) 609 610 http.HandleFunc("/usercmd", userCmd) 611 http.HandleFunc("/errout", errOutput) 612 http.HandleFunc("/panicHandle", panicHandle) 613 http.HandleFunc("www.qq.com/", userCmd) 614 http.HandleFunc("anything/", userCmd) 615 616 l := mustListenTCP(t) 617 go func() { 618 if err := http.Serve(l, nil); err != nil { 619 t.Log(err) 620 } 621 }() 622 time.Sleep(200 * time.Millisecond) 623 624 mux, err = extractServeMuxData() 625 require.Nil(t, err) 626 require.Equal(t, 5, len(mux.m)) 627 require.Equal(t, 2, len(mux.es)) 628 require.Equal(t, true, mux.hosts) 629 630 err = unregisterHandlers( 631 []string{ 632 "/usercmd", 633 "/errout", 634 "/panicHandle", 635 "www.qq.com/", 636 "anything/", 637 }, 638 ) 639 require.Nil(t, err) 640 641 mux, err = extractServeMuxData() 642 require.Nil(t, err) 643 require.Len(t, mux.m, 0) 644 require.Len(t, mux.es, 0) 645 require.False(t, mux.hosts) 646 647 resp1, err := http.Get(fmt.Sprintf("http://%v/usercmd", l.Addr())) 648 require.Nil(t, err) 649 defer resp1.Body.Close() 650 require.Equal(t, http.StatusNotFound, resp1.StatusCode) 651 652 http.HandleFunc("/usercmd", userCmd) 653 http.HandleFunc("/errout", errOutput) 654 http.HandleFunc("/panicHandle", panicHandle) 655 656 mux, err = extractServeMuxData() 657 require.Nil(t, err) 658 require.Len(t, mux.m, 3) 659 require.Len(t, mux.es, 0) 660 require.False(t, mux.hosts) 661 662 resp2, err := http.Get(fmt.Sprintf("http://%v/usercmd", l.Addr())) 663 require.Nil(t, err) 664 defer resp2.Body.Close() 665 respBody, err := io.ReadAll(resp2.Body) 666 require.Nil(t, err) 667 require.Equal(t, []byte("usercmd"), respBody) 668 } 669 func mustListenTCP(t *testing.T) *net.TCPListener { 670 l, err := net.Listen("tcp", testAddress) 671 if err != nil { 672 t.Fatal(err) 673 } 674 return l.(*net.TCPListener) 675 } 676 677 // serveMux keep the same structure with http.ServeMux 678 type serveMux struct { 679 m map[string]muxEntry 680 es []muxEntry 681 hosts bool 682 } 683 684 // muxEntry keep the same structure with muxEntry in net/http pkg 685 type muxEntry struct { 686 } 687 688 // extractServeMuxData get http.DefaultServeMux 's data and show 689 func extractServeMuxData() (*serveMux, error) { 690 v := reflect.ValueOf(http.DefaultServeMux) 691 692 // lock 693 muField := v.Elem().FieldByName("mu") 694 if !muField.IsValid() { 695 return nil, errors.New("http.DefaultServeMux does not have a field called `mu`") 696 } 697 muPointer := unsafe.Pointer(muField.UnsafeAddr()) 698 mu := (*sync.RWMutex)(muPointer) 699 (*mu).Lock() 700 defer (*mu).Unlock() 701 702 // get value of map 703 mField := v.Elem().FieldByName("m") 704 if !mField.IsValid() { 705 return nil, errors.New("http.DefaultServeMux does not have a field called `m`") 706 } 707 mPointer := unsafe.Pointer(mField.UnsafeAddr()) 708 m := (*map[string]muxEntry)(mPointer) 709 710 // get value of slice 711 esField := v.Elem().FieldByName("es") 712 if !esField.IsValid() { 713 return nil, errors.New("http.DefaultServeMux does not have a field called `es`") 714 } 715 esPointer := unsafe.Pointer(esField.UnsafeAddr()) 716 es := (*[]muxEntry)(esPointer) 717 718 // get hosts 719 hostsField := v.Elem().FieldByName("hosts") 720 if !hostsField.IsValid() { 721 return nil, errors.New("http.DefaultServeMux does not have a field called `hosts`") 722 } 723 hostsPointer := unsafe.Pointer(hostsField.UnsafeAddr()) 724 hosts := (*bool)(hostsPointer) 725 726 return &serveMux{ 727 m: *m, 728 es: *es, 729 hosts: *hosts, 730 }, nil 731 } 732 733 func TestTrpcAdminServer(t *testing.T) { 734 s := NewServer(WithAddr("invalid addr")) 735 err := s.Serve() 736 require.NotNil(t, err) 737 738 s = NewServer(WithAddr(testAddress)) 739 err = s.Register(struct{}{}, struct{}{}) 740 require.Nil(t, err) 741 742 go func() { 743 if err := s.Serve(); err != nil { 744 t.Log(err) 745 } 746 }() 747 time.Sleep(200 * time.Millisecond) 748 749 ch := make(chan struct{}, 1) 750 err = s.Close(ch) 751 closed := <-ch 752 require.NotNil(t, closed) 753 require.Nil(t, err) 754 }