trpc.group/trpc-go/trpc-go@v1.0.3/admin/admin_test.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package admin
    15  
    16  import (
    17  	"context"
    18  	"encoding/json"
    19  	"errors"
    20  	"fmt"
    21  	"io"
    22  	"net"
    23  	"net/http"
    24  	"os"
    25  	"reflect"
    26  	"strings"
    27  	"sync"
    28  	"testing"
    29  	"time"
    30  	"unsafe"
    31  
    32  	"github.com/stretchr/testify/assert"
    33  	"github.com/stretchr/testify/require"
    34  
    35  	"trpc.group/trpc-go/trpc-go/config"
    36  	"trpc.group/trpc-go/trpc-go/healthcheck"
    37  	"trpc.group/trpc-go/trpc-go/log"
    38  	"trpc.group/trpc-go/trpc-go/rpcz"
    39  	"trpc.group/trpc-go/trpc-go/transport"
    40  )
    41  
    42  const (
    43  	testVersion    = "v0.2.0-alpha"
    44  	testAddress    = "localhost:0"
    45  	testConfigPath = "../testdata/trpc_go.yaml"
    46  )
    47  
    48  func newDefaultAdminServer() *Server {
    49  	s := NewServer(
    50  		WithVersion(testVersion),
    51  		WithAddr(testAddress),
    52  		WithTLS(false),
    53  		WithReadTimeout(defaultReadTimeout),
    54  		WithWriteTimeout(defaultWriteTimeout),
    55  		WithConfigPath(testConfigPath),
    56  	)
    57  
    58  	s.HandleFunc("/usercmd", userCmd)
    59  	s.HandleFunc("/errout", errOutput)
    60  	s.HandleFunc("/panicHandle", panicHandle)
    61  
    62  	return s
    63  }
    64  
    65  func mustStartAdminServer(t *testing.T, s *Server) {
    66  	t.Helper()
    67  
    68  	go func() {
    69  		if err := s.Serve(); err != nil {
    70  			t.Log(err)
    71  		}
    72  	}()
    73  	time.Sleep(200 * time.Millisecond)
    74  }
    75  
    76  func TestRPCZFailed(t *testing.T) {
    77  	s := newDefaultAdminServer()
    78  	mustStartAdminServer(t, s)
    79  	t.Cleanup(func() {
    80  		if err := s.Close(nil); err != nil {
    81  			t.Log(err)
    82  		}
    83  	})
    84  	tests := []struct {
    85  		name      string
    86  		url       string
    87  		errorCode int
    88  		message   string
    89  		content   interface{}
    90  	}{
    91  		{
    92  			name:      "handleSpans failed because query parameter isn't a number",
    93  			url:       fmt.Sprintf("http://%s", s.server.Addr) + patternRPCZSpansList + "?num=xxx",
    94  			errorCode: errCodeServer,
    95  			message:   "must be a integer",
    96  			content:   "",
    97  		},
    98  		{
    99  			name:      "handleSpans failed because query parameter isn't a positive integer",
   100  			url:       fmt.Sprintf("http://%s", s.server.Addr) + patternRPCZSpansList + "?num=-1",
   101  			errorCode: errCodeServer,
   102  			message:   "must be a non-negative integer",
   103  			content:   nil,
   104  		},
   105  		{
   106  			name:      "handleSpan failed because can't find span_id",
   107  			url:       fmt.Sprintf("http://%s", s.server.Addr) + patternRPCZSpanGet + "1",
   108  			errorCode: errCodeServer,
   109  			message:   "cannot find span-id",
   110  			content:   nil,
   111  		},
   112  		{
   113  			name:      "handleSpan failed because query parameter span_id is empty",
   114  			url:       fmt.Sprintf("http://%s", s.server.Addr) + patternRPCZSpanGet + "",
   115  			errorCode: errCodeServer,
   116  			message:   "undefined command",
   117  			content:   nil,
   118  		},
   119  		{
   120  			name:      "handleSpan failed because query parameter span_id is negative",
   121  			url:       fmt.Sprintf("http://%s", s.server.Addr) + patternRPCZSpanGet + "-1",
   122  			errorCode: errCodeServer,
   123  			message:   "can not be negative",
   124  			content:   nil,
   125  		},
   126  	}
   127  	for _, tt := range tests {
   128  		t.Run(tt.name, func(t *testing.T) {
   129  			r, err := httpRequest(http.MethodGet, tt.url, "")
   130  			require.Nil(t, err)
   131  			require.Contains(t, string(r), tt.message)
   132  		})
   133  	}
   134  	t.Run("url query doesn't match rpcz", func(t *testing.T) {
   135  		r, err := httpRequest(http.MethodGet, fmt.Sprintf("http://%s", s.server.Addr)+"/cmd/rpcz", "")
   136  		require.Nil(t, err)
   137  		require.Contains(t, string(r), "404 page not found")
   138  	})
   139  }
   140  
   141  type sliceSpanExporter struct {
   142  	spans []rpcz.ReadOnlySpan
   143  }
   144  
   145  func (e *sliceSpanExporter) Export(span *rpcz.ReadOnlySpan) {
   146  	e.spans = append(e.spans, *span)
   147  }
   148  
   149  func TestRPC_Exporter(t *testing.T) {
   150  	s := newDefaultAdminServer()
   151  	mustStartAdminServer(t, s)
   152  	t.Cleanup(func() {
   153  		if err := s.Close(nil); err != nil {
   154  			t.Log(err)
   155  		}
   156  	})
   157  	oldGlobalRPCZ := rpcz.GlobalRPCZ
   158  	defer func() {
   159  		rpcz.GlobalRPCZ = oldGlobalRPCZ
   160  	}()
   161  	// Given a GlobalRPCZ configured with exporter
   162  	exporter := &sliceSpanExporter{}
   163  	rpcz.GlobalRPCZ = rpcz.NewRPCZ(&rpcz.Config{Fraction: 1.0, Capacity: 10, Exporter: exporter})
   164  
   165  	// When End a "server" span with spanID.
   166  	span := rpcz.SpanFromContext(context.Background())
   167  	cs, end := span.NewChild("server")
   168  	spanID := cs.ID()
   169  	end.End()
   170  
   171  	// Then the exporter contain the span exported by the GlobalRPCZ
   172  	require.Len(t, exporter.spans, 1)
   173  	require.Equal(t, spanID, exporter.spans[0].ID)
   174  
   175  	// And the GlobalRPCZ still stores a copy of the exported span
   176  	rRaw, err := httpRequest(http.MethodGet, fmt.Sprintf("http://%s", s.server.Addr)+patternRPCZSpansList+"?num", "")
   177  	require.Nil(t, err)
   178  	require.Contains(t, string(rRaw), fmt.Sprint(spanID))
   179  }
   180  
   181  func TestRPCZOk(t *testing.T) {
   182  	s := newDefaultAdminServer()
   183  	mustStartAdminServer(t, s)
   184  	t.Cleanup(func() {
   185  		if err := s.Close(nil); err != nil {
   186  			t.Log(err)
   187  		}
   188  	})
   189  	oldGlobalRPCZ := rpcz.GlobalRPCZ
   190  	defer func() {
   191  		rpcz.GlobalRPCZ = oldGlobalRPCZ
   192  	}()
   193  	rpcz.GlobalRPCZ = rpcz.NewRPCZ(&rpcz.Config{Fraction: 1.0, Capacity: 10})
   194  	span := rpcz.SpanFromContext(context.Background())
   195  
   196  	cs, end := span.NewChild("server")
   197  	spanID := cs.ID()
   198  	end.End()
   199  
   200  	tests := []struct {
   201  		name      string
   202  		url       string
   203  		errorCode int
   204  		message   string
   205  		content   interface{}
   206  	}{
   207  		{
   208  			name:    "handleSpans ok query parameter num is empty",
   209  			url:     fmt.Sprintf("http://%s", s.server.Addr) + patternRPCZSpansList + "?num",
   210  			content: fmt.Sprintf("1:\n  span: (server, %d)\n", spanID),
   211  		},
   212  		{
   213  			name:    "handleSpans ok without any query parameter",
   214  			url:     fmt.Sprintf("http://%s", s.server.Addr) + patternRPCZSpansList,
   215  			content: fmt.Sprintf("1:\n  span: (server, %d)\n", spanID),
   216  		},
   217  		{
   218  			name:    "handleSpans ok",
   219  			url:     fmt.Sprintf("http://%s", s.server.Addr) + patternRPCZSpansList + "?num=1",
   220  			content: fmt.Sprintf("1:\n  span: (server, %d)\n", spanID),
   221  		},
   222  		{
   223  			name:    "handleSpan ok",
   224  			url:     fmt.Sprintf("http://%s", s.server.Addr) + patternRPCZSpanGet + fmt.Sprint(spanID),
   225  			content: fmt.Sprintf("span: (server, %d)\n", spanID),
   226  		},
   227  	}
   228  	for _, tt := range tests {
   229  		t.Run(tt.name, func(t *testing.T) {
   230  			rRaw, err := httpRequest(http.MethodGet, tt.url, "")
   231  			r := string(rRaw)
   232  			require.Nil(t, err)
   233  			require.Contains(t, r, tt.message)
   234  			require.Contains(t, r, tt.content)
   235  
   236  		})
   237  	}
   238  }
   239  
   240  func TestCmdVersion(t *testing.T) {
   241  	s := newDefaultAdminServer()
   242  	mustStartAdminServer(t, s)
   243  	t.Cleanup(func() {
   244  		if err := s.Close(nil); err != nil {
   245  			t.Log(err)
   246  		}
   247  	})
   248  	versionURL := fmt.Sprintf("http://%s", s.server.Addr) + "/version"
   249  	respData, err := httpRequest(http.MethodGet, versionURL, "")
   250  	if err != nil {
   251  		require.Nil(t, err, "httpGetBody failed")
   252  		return
   253  	}
   254  
   255  	res := struct {
   256  		Errcode int    `json:"errorcode"`
   257  		Message string `json:"message"`
   258  		Version string `json:"version"`
   259  	}{}
   260  	err = json.Unmarshal(respData, &res)
   261  	require.Nil(t, err, "testAdminServerVersion unmarshal failed")
   262  	require.Equal(t, 0, res.Errcode)
   263  	require.Equal(t, testVersion, res.Version)
   264  }
   265  
   266  func TestCmdsLogLevel(t *testing.T) {
   267  	s := newDefaultAdminServer()
   268  	mustStartAdminServer(t, s)
   269  	t.Cleanup(func() {
   270  		if err := s.Close(nil); err != nil {
   271  			t.Log(err)
   272  		}
   273  	})
   274  
   275  	dlogger := log.GetDefaultLogger()
   276  
   277  	// Preset test conditions
   278  	log.Register("default", log.NewZapLog([]log.OutputConfig{
   279  		{Writer: log.OutputConsole, Level: "debug"},
   280  		{Writer: log.OutputFile, WriteConfig: log.WriteConfig{Filename: "test"}, Level: "info"},
   281  	}))
   282  
   283  	t.Cleanup(func() {
   284  		log.Register("default", dlogger)
   285  	})
   286  
   287  	res := struct {
   288  		Errcode  int    `json:"errorcode"`
   289  		Message  string `json:"message"`
   290  		Level    string `json:"level"`
   291  		PreLevel string `json:"prelevel"`
   292  	}{}
   293  
   294  	t.Run("right case", func(t *testing.T) {
   295  		logURL := fmt.Sprintf("http://%s", s.server.Addr) + "/cmds/loglevel?logger=default&output=1"
   296  		// TestGet
   297  		respData, err := httpRequest(http.MethodGet, logURL, "")
   298  		require.Nil(t, err, "httpGetBody failed")
   299  
   300  		err = json.Unmarshal(respData, &res)
   301  		require.Nil(t, err, "testAdminServerLogLevel unmarshal failed")
   302  		require.Equal(t, 0, res.Errcode)
   303  		require.Equal(t, "info", res.Level)
   304  
   305  		// TestUpdate
   306  		body, err := httpRequest(http.MethodPut, logURL, "value=debug")
   307  		require.Nil(t, err, "httpRequest failed:", err)
   308  		err = json.Unmarshal(body, &res)
   309  		require.Nil(t, err, "Unmarshal failed:", err)
   310  		require.Equal(t, 0, res.Errcode)
   311  		require.Equal(t, "info", res.PreLevel)
   312  		require.Equal(t, "debug", res.Level)
   313  	})
   314  	t.Run("request parameter is empty", func(t *testing.T) {
   315  		logURL := fmt.Sprintf("http://%s", s.server.Addr) + "/cmds/loglevel"
   316  		respData, err := httpRequest(http.MethodGet, logURL, "")
   317  		require.Nil(t, err, "httpGetBody failed")
   318  
   319  		err = json.Unmarshal(respData, &res)
   320  		require.Nil(t, err, "unmarshal failed")
   321  		require.Equal(t, 0, res.Errcode)
   322  		require.Equal(t, "debug", res.Level)
   323  	})
   324  	t.Run("failed to parse request parameters", func(t *testing.T) {
   325  		logURL := fmt.Sprintf("http://%s", s.server.Addr) + "/cmds/loglevel?logger%"
   326  		respData, err := httpRequest(http.MethodGet, logURL, "")
   327  		require.Nil(t, err, "httpGetBody failed:", err)
   328  
   329  		err = json.Unmarshal(respData, &res)
   330  		require.Nil(t, err, "Unmarshal failed", err)
   331  		require.Equal(t, errCodeServer, res.Errcode)
   332  	})
   333  	t.Run("logger is invalid", func(t *testing.T) {
   334  		logURL := fmt.Sprintf("http://%s", s.server.Addr) + "/cmds/loglevel?logger=invalid"
   335  		respData, err := httpRequest(http.MethodGet, logURL, "")
   336  		require.Nil(t, err, "httpGetBody failed:", err)
   337  
   338  		err = json.Unmarshal(respData, &res)
   339  		require.Nil(t, err, "Unmarshal failed", err)
   340  		require.Equal(t, errCodeServer, res.Errcode)
   341  		require.Equal(t, "logger invalid not found", res.Message)
   342  	})
   343  }
   344  
   345  func TestCmdsConfig(t *testing.T) {
   346  	s := newDefaultAdminServer()
   347  	mustStartAdminServer(t, s)
   348  	t.Cleanup(func() {
   349  		if err := s.Close(nil); err != nil {
   350  			t.Log(err)
   351  		}
   352  	})
   353  	configURL := fmt.Sprintf("http://%s//cmds/config", s.server.Addr)
   354  	res := struct {
   355  		Errcode int         `json:"errorcode"`
   356  		Message string      `json:"message"`
   357  		Content interface{} `json:"content"`
   358  	}{}
   359  	t.Run("failed to read configuration file", func(t *testing.T) {
   360  		// Replace invalid config path
   361  		s.config.configPath = "./invalid/invalid.yaml"
   362  		respData, err := httpRequest(http.MethodGet, configURL, "")
   363  		// Adjust back to the correct path
   364  		s.config.configPath = testConfigPath
   365  		require.Nil(t, err, "httpGetBody failed")
   366  
   367  		err = json.Unmarshal(respData, &res)
   368  		require.Nil(t, err, "unmarshal failed", err)
   369  		require.Equal(t, errCodeServer, res.Errcode)
   370  	})
   371  	t.Run("failed to get unmarshaler", func(t *testing.T) {
   372  		// Replace invalid unmarshaler
   373  		config.RegisterUnmarshaler("yaml", nil)
   374  		respData, err := httpRequest(http.MethodGet, configURL, "")
   375  		// Adjust back to the correct unmarshaler
   376  		config.RegisterUnmarshaler("yaml", &config.YamlUnmarshaler{})
   377  		if err != nil {
   378  			require.Nil(t, err, "httpGetBody failed")
   379  			return
   380  		}
   381  
   382  		err = json.Unmarshal(respData, &res)
   383  		require.Nil(t, err, "unmarshal failed", err)
   384  		require.Equal(t, errCodeServer, res.Errcode)
   385  		require.Equal(t, "cannot find yaml unmarshaler", res.Message)
   386  	})
   387  	t.Run("failed to unmarshal configuration file", func(t *testing.T) {
   388  		// Replace invalid config path
   389  		s.config.configPath = "../testdata/greeter.trpc.go"
   390  		respData, err := httpRequest(http.MethodGet, configURL, "")
   391  		// Adjust back to the correct path
   392  		s.config.configPath = testConfigPath
   393  		require.Nil(t, err, "httpGetBody failed")
   394  
   395  		err = json.Unmarshal(respData, &res)
   396  		require.Nil(t, err, "unmarshal failed", err)
   397  		require.Equal(t, errCodeServer, res.Errcode)
   398  	})
   399  	t.Run("right case", func(t *testing.T) {
   400  		time.Sleep(1 * time.Second)
   401  		respData, err := httpRequest(http.MethodGet, configURL, "")
   402  		require.Nil(t, err, "httpGetBody failed")
   403  
   404  		err = json.Unmarshal(respData, &res)
   405  		require.Nil(t, err, "unmarshal failed", err)
   406  		require.Equal(t, 0, res.Errcode)
   407  		require.NotNil(t, res.Content, "config content is empty")
   408  	})
   409  }
   410  
   411  func TestCmdsHealthCheck(t *testing.T) {
   412  	s := newDefaultAdminServer()
   413  	mustStartAdminServer(t, s)
   414  	t.Cleanup(func() {
   415  		if err := s.Close(nil); err != nil {
   416  			t.Log(err)
   417  		}
   418  	})
   419  
   420  	rsp, err := http.Get(fmt.Sprintf("http://%s/is_healthy", s.server.Addr))
   421  	require.Nil(t, err)
   422  	require.Equal(t, http.StatusOK, rsp.StatusCode)
   423  
   424  	rsp, err = http.Get(fmt.Sprintf("http://%s/is_healthy/", s.server.Addr))
   425  	require.Nil(t, err)
   426  	require.Equal(t, http.StatusOK, rsp.StatusCode)
   427  
   428  	rsp, err = http.Get(fmt.Sprintf("http://%s/is_healthy/not_exist", s.server.Addr))
   429  	require.Nil(t, err)
   430  	require.Equal(t, http.StatusNotFound, rsp.StatusCode)
   431  
   432  	unregister, update, err := s.RegisterHealthCheck("service")
   433  	require.Nil(t, err)
   434  	rsp, err = http.Get(fmt.Sprintf("http://%s/is_healthy", s.server.Addr))
   435  	require.Nil(t, err)
   436  	require.Equal(t, http.StatusNotFound, rsp.StatusCode)
   437  	rsp, err = http.Get(fmt.Sprintf("http://%s/is_healthy/service", s.server.Addr))
   438  	require.Nil(t, err)
   439  	require.Equal(t, http.StatusNotFound, rsp.StatusCode)
   440  
   441  	update(healthcheck.Serving)
   442  	rsp, err = http.Get(fmt.Sprintf("http://%s/is_healthy", s.server.Addr))
   443  	require.Nil(t, err)
   444  	require.Equal(t, http.StatusOK, rsp.StatusCode)
   445  	rsp, err = http.Get(fmt.Sprintf("http://%s/is_healthy/service", s.server.Addr))
   446  	require.Nil(t, err)
   447  	require.Equal(t, http.StatusOK, rsp.StatusCode)
   448  
   449  	update(healthcheck.NotServing)
   450  	rsp, err = http.Get(fmt.Sprintf("http://%s/is_healthy", s.server.Addr))
   451  	require.Nil(t, err)
   452  	require.Equal(t, http.StatusServiceUnavailable, rsp.StatusCode)
   453  	rsp, err = http.Get(fmt.Sprintf("http://%s/is_healthy/service", s.server.Addr))
   454  	require.Nil(t, err)
   455  	require.Equal(t, http.StatusServiceUnavailable, rsp.StatusCode)
   456  
   457  	unregister()
   458  	rsp, err = http.Get(fmt.Sprintf("http://%s/is_healthy", s.server.Addr))
   459  	require.Nil(t, err)
   460  	require.Equal(t, http.StatusOK, rsp.StatusCode)
   461  	rsp, err = http.Get(fmt.Sprintf("http://%s/is_healthy/service", s.server.Addr))
   462  	require.Nil(t, err)
   463  	require.Equal(t, http.StatusNotFound, rsp.StatusCode)
   464  }
   465  
   466  func TestCmds(t *testing.T) {
   467  	s := newDefaultAdminServer()
   468  	mustStartAdminServer(t, s)
   469  	t.Cleanup(func() {
   470  		if err := s.Close(nil); err != nil {
   471  			t.Log(err)
   472  		}
   473  	})
   474  
   475  	usercmdURL := fmt.Sprintf("http://%s", s.server.Addr) + "/cmds"
   476  	respData, err := httpRequest(http.MethodGet, usercmdURL, "")
   477  	require.Nil(t, err, "cmds request failed")
   478  
   479  	res := struct {
   480  		Errcode int      `json:"errorcode"`
   481  		Message string   `json:"message"`
   482  		Cmds    []string `json:"cmds"`
   483  	}{}
   484  	err = json.Unmarshal(respData, &res)
   485  	require.Nil(t, err, "Unmarshal failed")
   486  }
   487  
   488  func TestErrorOutput(t *testing.T) {
   489  	s := newDefaultAdminServer()
   490  	mustStartAdminServer(t, s)
   491  	t.Cleanup(func() {
   492  		if err := s.Close(nil); err != nil {
   493  			t.Log(err)
   494  		}
   495  	})
   496  	usercmdURL := fmt.Sprintf("http://%s", s.server.Addr) + "/errout"
   497  	respData, err := httpRequest(http.MethodGet, usercmdURL, "")
   498  	require.Nil(t, err, "cmds request failed")
   499  
   500  	res := struct {
   501  		Errcode int    `json:"errorcode"`
   502  		Message string `json:"message"`
   503  	}{}
   504  	err = json.Unmarshal(respData, &res)
   505  	require.Nil(t, err, "Unmarshal failed")
   506  	require.Equal(t, 100, res.Errcode)
   507  	require.Contains(t, res.Message, "error")
   508  }
   509  
   510  func TestPanicHandle(t *testing.T) {
   511  	s := newDefaultAdminServer()
   512  	mustStartAdminServer(t, s)
   513  	t.Cleanup(func() {
   514  		if err := s.Close(nil); err != nil {
   515  			t.Log(err)
   516  		}
   517  	})
   518  
   519  	usercmdURL := fmt.Sprintf("http://%s", s.server.Addr) + "/panicHandle"
   520  	respData, err := httpRequest(http.MethodGet, usercmdURL, "")
   521  	require.Nil(t, err, "cmds request failed")
   522  
   523  	res := struct {
   524  		Errcode int    `json:"errorcode"`
   525  		Message string `json:"message"`
   526  	}{}
   527  	err = json.Unmarshal(respData, &res)
   528  	require.Nil(t, err, "Unmarshal failed")
   529  	require.Equal(t, 500, res.Errcode)
   530  	require.Contains(t, res.Message, "panic")
   531  }
   532  
   533  func TestListen(t *testing.T) {
   534  	s := NewServer()
   535  
   536  	// listen fail on invalid address
   537  	err := os.Setenv(transport.EnvGraceRestart, "0")
   538  	assert.Nil(t, err)
   539  	ln, err := s.listen("tcp", "invalid address")
   540  	assert.NotNil(t, err)
   541  	assert.Nil(t, ln)
   542  
   543  	// listen success
   544  	ln, err = s.listen("tcp", "127.0.0.1:0")
   545  	assert.Nil(t, err)
   546  	assert.NotNil(t, ln)
   547  	defer func(ln net.Listener) {
   548  		assert.Nil(t, ln.Close())
   549  	}(ln)
   550  	assert.IsType(t, &net.TCPListener{}, ln)
   551  }
   552  
   553  func TestClose(t *testing.T) {
   554  	s := newDefaultAdminServer()
   555  	mustStartAdminServer(t, s)
   556  
   557  	err := s.Close(nil)
   558  	require.Nil(t, err)
   559  
   560  	usercmdURL := fmt.Sprintf("http://%s/cmds", s.server.Addr)
   561  	_, err = httpRequest(http.MethodGet, usercmdURL, "")
   562  	var netErr *net.OpError
   563  
   564  	require.ErrorAs(t, err, &netErr)
   565  }
   566  
   567  func TestOptionsConfig(t *testing.T) {
   568  	s := newDefaultAdminServer()
   569  	WithTLS(true)(s.config)
   570  	err := s.Serve()
   571  	require.NotNil(t, err)
   572  	require.Contains(t, err.Error(), "not support")
   573  }
   574  
   575  func httpRequest(method string, url string, body string) ([]byte, error) {
   576  	request, err := http.NewRequest(method, url, strings.NewReader(body))
   577  	request.Header.Set("content-type", "application/x-www-form-urlencoded")
   578  	if err != nil {
   579  		return nil, err
   580  	}
   581  
   582  	response, err := http.DefaultClient.Do(request)
   583  	if err != nil {
   584  		return nil, err
   585  	}
   586  	defer response.Body.Close()
   587  	return io.ReadAll(response.Body)
   588  }
   589  
   590  func userCmd(w http.ResponseWriter, r *http.Request) {
   591  	_, _ = w.Write([]byte("usercmd"))
   592  }
   593  
   594  func errOutput(w http.ResponseWriter, r *http.Request) {
   595  	ErrorOutput(w, "error output", 100)
   596  }
   597  
   598  func panicHandle(w http.ResponseWriter, r *http.Request) {
   599  	panic("panic error handle")
   600  }
   601  
   602  func TestUnregisterHandlers(t *testing.T) {
   603  	_ = newDefaultAdminServer()
   604  	mux, err := extractServeMuxData()
   605  	require.Nil(t, err)
   606  	require.Len(t, mux.m, 0)
   607  	require.Len(t, mux.es, 0)
   608  	require.False(t, mux.hosts)
   609  
   610  	http.HandleFunc("/usercmd", userCmd)
   611  	http.HandleFunc("/errout", errOutput)
   612  	http.HandleFunc("/panicHandle", panicHandle)
   613  	http.HandleFunc("www.qq.com/", userCmd)
   614  	http.HandleFunc("anything/", userCmd)
   615  
   616  	l := mustListenTCP(t)
   617  	go func() {
   618  		if err := http.Serve(l, nil); err != nil {
   619  			t.Log(err)
   620  		}
   621  	}()
   622  	time.Sleep(200 * time.Millisecond)
   623  
   624  	mux, err = extractServeMuxData()
   625  	require.Nil(t, err)
   626  	require.Equal(t, 5, len(mux.m))
   627  	require.Equal(t, 2, len(mux.es))
   628  	require.Equal(t, true, mux.hosts)
   629  
   630  	err = unregisterHandlers(
   631  		[]string{
   632  			"/usercmd",
   633  			"/errout",
   634  			"/panicHandle",
   635  			"www.qq.com/",
   636  			"anything/",
   637  		},
   638  	)
   639  	require.Nil(t, err)
   640  
   641  	mux, err = extractServeMuxData()
   642  	require.Nil(t, err)
   643  	require.Len(t, mux.m, 0)
   644  	require.Len(t, mux.es, 0)
   645  	require.False(t, mux.hosts)
   646  
   647  	resp1, err := http.Get(fmt.Sprintf("http://%v/usercmd", l.Addr()))
   648  	require.Nil(t, err)
   649  	defer resp1.Body.Close()
   650  	require.Equal(t, http.StatusNotFound, resp1.StatusCode)
   651  
   652  	http.HandleFunc("/usercmd", userCmd)
   653  	http.HandleFunc("/errout", errOutput)
   654  	http.HandleFunc("/panicHandle", panicHandle)
   655  
   656  	mux, err = extractServeMuxData()
   657  	require.Nil(t, err)
   658  	require.Len(t, mux.m, 3)
   659  	require.Len(t, mux.es, 0)
   660  	require.False(t, mux.hosts)
   661  
   662  	resp2, err := http.Get(fmt.Sprintf("http://%v/usercmd", l.Addr()))
   663  	require.Nil(t, err)
   664  	defer resp2.Body.Close()
   665  	respBody, err := io.ReadAll(resp2.Body)
   666  	require.Nil(t, err)
   667  	require.Equal(t, []byte("usercmd"), respBody)
   668  }
   669  func mustListenTCP(t *testing.T) *net.TCPListener {
   670  	l, err := net.Listen("tcp", testAddress)
   671  	if err != nil {
   672  		t.Fatal(err)
   673  	}
   674  	return l.(*net.TCPListener)
   675  }
   676  
   677  // serveMux keep the same structure with http.ServeMux
   678  type serveMux struct {
   679  	m     map[string]muxEntry
   680  	es    []muxEntry
   681  	hosts bool
   682  }
   683  
   684  // muxEntry keep the same structure with muxEntry in net/http pkg
   685  type muxEntry struct {
   686  }
   687  
// extractServeMuxData snapshots the unexported m, es and hosts fields of
// http.DefaultServeMux. It reaches them through reflection plus unsafe pointer
// casts and holds the mux's mu lock while reading.
//
// It returns an error naming the first of mu, m, es or hosts that
// http.DefaultServeMux does not have.
//
// Caveats:
//   - The returned map and slice share storage with the live mux and are read
//     after the lock has been released. Treat them as a point-in-time view.
//   - muxEntry is an empty stand-in, so only len() of m and es is meaningful.
//     Never read the element values.
//   - NOTE(review): this depends on net/http's internal field layout.
//     ServeMux was restructured in Go 1.22, where these fields are presumably
//     no longer at the top level, so this would return an error — confirm.
func extractServeMuxData() (*serveMux, error) {
	v := reflect.ValueOf(http.DefaultServeMux)

	// lock: take the mux's own mutex (assumed to be a sync.RWMutex) so the
	// fields are not mutated by a concurrent registration while being copied.
	muField := v.Elem().FieldByName("mu")
	if !muField.IsValid() {
		return nil, errors.New("http.DefaultServeMux does not have a field called `mu`")
	}
	muPointer := unsafe.Pointer(muField.UnsafeAddr())
	mu := (*sync.RWMutex)(muPointer)
	(*mu).Lock()
	defer (*mu).Unlock()

	// get value of map: reinterpret the field as a map header; only its length
	// is safe to use because muxEntry does not match the real element type.
	mField := v.Elem().FieldByName("m")
	if !mField.IsValid() {
		return nil, errors.New("http.DefaultServeMux does not have a field called `m`")
	}
	mPointer := unsafe.Pointer(mField.UnsafeAddr())
	m := (*map[string]muxEntry)(mPointer)

	// get value of slice: same idea, only the slice length is meaningful.
	esField := v.Elem().FieldByName("es")
	if !esField.IsValid() {
		return nil, errors.New("http.DefaultServeMux does not have a field called `es`")
	}
	esPointer := unsafe.Pointer(esField.UnsafeAddr())
	es := (*[]muxEntry)(esPointer)

	// get hosts flag.
	hostsField := v.Elem().FieldByName("hosts")
	if !hostsField.IsValid() {
		return nil, errors.New("http.DefaultServeMux does not have a field called `hosts`")
	}
	hostsPointer := unsafe.Pointer(hostsField.UnsafeAddr())
	hosts := (*bool)(hostsPointer)

	// Copy the headers/values out while still holding the lock.
	return &serveMux{
		m:     *m,
		es:    *es,
		hosts: *hosts,
	}, nil
}
   732  
   733  func TestTrpcAdminServer(t *testing.T) {
   734  	s := NewServer(WithAddr("invalid addr"))
   735  	err := s.Serve()
   736  	require.NotNil(t, err)
   737  
   738  	s = NewServer(WithAddr(testAddress))
   739  	err = s.Register(struct{}{}, struct{}{})
   740  	require.Nil(t, err)
   741  
   742  	go func() {
   743  		if err := s.Serve(); err != nil {
   744  			t.Log(err)
   745  		}
   746  	}()
   747  	time.Sleep(200 * time.Millisecond)
   748  
   749  	ch := make(chan struct{}, 1)
   750  	err = s.Close(ch)
   751  	closed := <-ch
   752  	require.NotNil(t, closed)
   753  	require.Nil(t, err)
   754  }