trpc.group/trpc-go/trpc-go@v1.0.3/admin/mux.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package admin
    15  
    16  import (
    17  	"errors"
    18  	"net/http"
    19  	"reflect"
    20  	"sync"
    21  	"unsafe"
    22  )
    23  
    24  // unregisterHandlers deletes router from http.DefaultServeMux.
    25  // The import of "net/http/pprof" will automatically register pprof related routes on
    26  // http.DefaultServeMux, which may cause security problems.
    27  // Refer to:https://github.com/golang/go/issues/22085
    28  func unregisterHandlers(patterns []string) error {
    29  	// Need to import muxEntry in net/http pkg.
    30  	type muxEntry struct {
    31  		h       http.Handler
    32  		pattern string
    33  	}
    34  
    35  	v := reflect.ValueOf(http.DefaultServeMux)
    36  
    37  	// Get lock.
    38  	muField := v.Elem().FieldByName("mu")
    39  	if !muField.IsValid() {
    40  		return errors.New("http.DefaultServeMux does not have a field called `mu`")
    41  	}
    42  	muPointer := unsafe.Pointer(muField.UnsafeAddr())
    43  	mu := (*sync.RWMutex)(muPointer)
    44  	(*mu).Lock()
    45  	defer (*mu).Unlock()
    46  
    47  	// Delete value of map.
    48  	mField := v.Elem().FieldByName("m")
    49  	if !mField.IsValid() {
    50  		return errors.New("http.DefaultServeMux does not have a field called `m`")
    51  	}
    52  	mPointer := unsafe.Pointer(mField.UnsafeAddr())
    53  	m := (*map[string]muxEntry)(mPointer)
    54  	for _, pattern := range patterns {
    55  		delete(*m, pattern)
    56  	}
    57  
    58  	// Delete value of muxEntry slice.
    59  	esField := v.Elem().FieldByName("es")
    60  	if !esField.IsValid() {
    61  		return errors.New("http.DefaultServeMux does not have a field called `es`")
    62  	}
    63  	esPointer := unsafe.Pointer(esField.UnsafeAddr())
    64  	es := (*[]muxEntry)(esPointer)
    65  	for _, pattern := range patterns {
    66  		// Removes muxEntry of the same pattern.
    67  		var j int
    68  		for _, muxEntry := range *es {
    69  			if muxEntry.pattern != pattern {
    70  				(*es)[j] = muxEntry
    71  				j++
    72  			}
    73  		}
    74  		*es = (*es)[:j]
    75  	}
    76  
    77  	// Modify hosts.
    78  	hostsField := v.Elem().FieldByName("hosts")
    79  	if !hostsField.IsValid() {
    80  		return errors.New("http.DefaultServeMux does not have a field called `hosts`")
    81  	}
    82  	hostsPointer := unsafe.Pointer(hostsField.UnsafeAddr())
    83  	hosts := (*bool)(hostsPointer)
    84  	*hosts = false
    85  	for _, v := range *m {
    86  		if v.pattern[0] != '/' {
    87  			*hosts = true
    88  		}
    89  	}
    90  
    91  	return nil
    92  }