
     1  package diagnostic
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"fmt"
     7  	"net"
     8  	"net/http"
     9  	"strconv"
    10  	"sync"
    11  	"sync/atomic"
    12  	"time"
    14  	""
    15  	""
    16  	""
    17  )
    19  // HTTPHandlerFunc TODO
    20  type HTTPHandlerFunc func(interface{}, http.ResponseWriter, *http.Request)
    22  type httpHandlerCustom struct {
    23  	ctx interface{}
    24  	F   func(interface{}, http.ResponseWriter, *http.Request)
    25  }
    27  // ServeHTTP TODO
    28  func (h httpHandlerCustom) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    29  	h.F(h.ctx, w, r)
    30  }
    32  var diagPaths2Func = map[string]HTTPHandlerFunc{
    33  	"/":          notImplemented,
    34  	"/help":      help,
    35  	"/ready":     ready,
    36  	"/stackdump": stackTrace,
    37  }
// Server is the diagnostic HTTP server. Handlers are attached to its mux with
// RegisterHandler (Init creates the mux and registers the built-in endpoints),
// and EnableDiagnostic / DisableDiagnostic start and stop an http.Server whose
// Handler is the Server itself, so every request is routed through that mux.
//
// The embedded sync.Mutex guards the fields below for most accesses in this
// file. NOTE(review): not every access takes it (e.g. help reads
// registeredHanders without the lock) — check call sites before relying on it.
type Server struct {
	enable            int32           // 1 once EnableDiagnostic started a server, 0 after disable or listen failure
	srv               *http.Server    // set by EnableDiagnostic, cleared by DisableDiagnostic
	port              int             // port given to EnableDiagnostic
	mux               *http.ServeMux  // created by Init; requests are dispatched through it
	registeredHanders map[string]bool // paths already on mux; later registrations of them are skipped
	sync.Mutex
}
    50  // New creates a new diagnostic server
    51  func New() *Server {
    52  	return &Server{
    53  		registeredHanders: make(map[string]bool),
    54  	}
    55  }
    57  // Init initialize the mux for the http handling and register the base hooks
    58  func (s *Server) Init() {
    59  	s.mux = http.NewServeMux()
    61  	// Register local handlers
    62  	s.RegisterHandler(s, diagPaths2Func)
    63  }
    65  // RegisterHandler allows to register new handlers to the mux and to a specific path
    66  func (s *Server) RegisterHandler(ctx interface{}, hdlrs map[string]HTTPHandlerFunc) {
    67  	s.Lock()
    68  	defer s.Unlock()
    69  	for path, fun := range hdlrs {
    70  		if _, ok := s.registeredHanders[path]; ok {
    71  			continue
    72  		}
    73  		s.mux.Handle(path, httpHandlerCustom{ctx, fun})
    74  		s.registeredHanders[path] = true
    75  	}
    76  }
    78  // ServeHTTP this is the method called bu the ListenAndServe, and is needed to allow us to
    79  // use our custom mux
    80  func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    81  	s.mux.ServeHTTP(w, r)
    82  }
    84  // EnableDiagnostic opens a TCP socket to debug the passed network DB
    85  func (s *Server) EnableDiagnostic(ip string, port int) {
    86  	s.Lock()
    87  	defer s.Unlock()
    89  	s.port = port
    91  	if s.enable == 1 {
    92  		log.G(context.TODO()).Info("The server is already up and running")
    93  		return
    94  	}
    96  	log.G(context.TODO()).Infof("Starting the diagnostic server listening on %d for commands", port)
    97  	srv := &http.Server{
    98  		Addr:              net.JoinHostPort(ip, strconv.Itoa(port)),
    99  		Handler:           s,
   100  		ReadHeaderTimeout: 5 * time.Minute, // "G112: Potential Slowloris Attack (gosec)"; not a real concern for our use, so setting a long timeout.
   101  	}
   102  	s.srv = srv
   103  	s.enable = 1
   104  	go func(n *Server) {
   105  		// Ignore ErrServerClosed that is returned on the Shutdown call
   106  		if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
   107  			log.G(context.TODO()).Errorf("ListenAndServe error: %s", err)
   108  			atomic.SwapInt32(&n.enable, 0)
   109  		}
   110  	}(s)
   111  }
   113  // DisableDiagnostic stop the dubug and closes the tcp socket
   114  func (s *Server) DisableDiagnostic() {
   115  	s.Lock()
   116  	defer s.Unlock()
   118  	s.srv.Shutdown(context.Background()) //nolint:errcheck
   119  	s.srv = nil
   120  	s.enable = 0
   121  	log.G(context.TODO()).Info("Disabling the diagnostic server")
   122  }
   124  // IsDiagnosticEnabled returns true when the debug is enabled
   125  func (s *Server) IsDiagnosticEnabled() bool {
   126  	s.Lock()
   127  	defer s.Unlock()
   128  	return s.enable == 1
   129  }
   131  func notImplemented(ctx interface{}, w http.ResponseWriter, r *http.Request) {
   132  	_ = r.ParseForm()
   133  	_, jsonOutput := ParseHTTPFormOptions(r)
   134  	rsp := WrongCommand("not implemented", fmt.Sprintf("URL path: %s no method implemented check /help\n", r.URL.Path))
   136  	// audit logs
   137  	log.G(context.TODO()).WithFields(log.Fields{
   138  		"component": "diagnostic",
   139  		"remoteIP":  r.RemoteAddr,
   140  		"method":    caller.Name(0),
   141  		"url":       r.URL.String(),
   142  	}).Info("command not implemented done")
   144  	_, _ = HTTPReply(w, rsp, jsonOutput)
   145  }
   147  func help(ctx interface{}, w http.ResponseWriter, r *http.Request) {
   148  	_ = r.ParseForm()
   149  	_, jsonOutput := ParseHTTPFormOptions(r)
   151  	// audit logs
   152  	log.G(context.TODO()).WithFields(log.Fields{
   153  		"component": "diagnostic",
   154  		"remoteIP":  r.RemoteAddr,
   155  		"method":    caller.Name(0),
   156  		"url":       r.URL.String(),
   157  	}).Info("help done")
   159  	n, ok := ctx.(*Server)
   160  	var result string
   161  	if ok {
   162  		for path := range n.registeredHanders {
   163  			result += fmt.Sprintf("%s\n", path)
   164  		}
   165  		_, _ = HTTPReply(w, CommandSucceed(&StringCmd{Info: result}), jsonOutput)
   166  	}
   167  }
   169  func ready(ctx interface{}, w http.ResponseWriter, r *http.Request) {
   170  	_ = r.ParseForm()
   171  	_, jsonOutput := ParseHTTPFormOptions(r)
   173  	// audit logs
   174  	log.G(context.TODO()).WithFields(log.Fields{
   175  		"component": "diagnostic",
   176  		"remoteIP":  r.RemoteAddr,
   177  		"method":    caller.Name(0),
   178  		"url":       r.URL.String(),
   179  	}).Info("ready done")
   180  	_, _ = HTTPReply(w, CommandSucceed(&StringCmd{Info: "OK"}), jsonOutput)
   181  }
   183  func stackTrace(ctx interface{}, w http.ResponseWriter, r *http.Request) {
   184  	_ = r.ParseForm()
   185  	_, jsonOutput := ParseHTTPFormOptions(r)
   187  	// audit logs
   188  	logger := log.G(context.TODO()).WithFields(log.Fields{"component": "diagnostic", "remoteIP": r.RemoteAddr, "method": caller.Name(0), "url": r.URL.String()})
   189  	logger.Info("stack trace")
   191  	path, err := stack.DumpToFile("/tmp/")
   192  	if err != nil {
   193  		logger.WithError(err).Error("failed to write goroutines dump")
   194  		_, _ = HTTPReply(w, FailCommand(err), jsonOutput)
   195  	} else {
   196  		logger.Info("stack trace done")
   197  		_, _ = HTTPReply(w, CommandSucceed(&StringCmd{Info: "goroutine stacks written to " + path}), jsonOutput)
   198  	}
   199  }
   201  // DebugHTTPForm helper to print the form url parameters
   202  func DebugHTTPForm(r *http.Request) {
   203  	for k, v := range r.Form {
   204  		log.G(context.TODO()).Debugf("Form[%q] = %q\n", k, v)
   205  	}
   206  }
   208  // JSONOutput contains details on JSON output printing
   209  type JSONOutput struct {
   210  	enable      bool
   211  	prettyPrint bool
   212  }
   214  // ParseHTTPFormOptions easily parse the JSON printing options
   215  func ParseHTTPFormOptions(r *http.Request) (bool, *JSONOutput) {
   216  	_, unsafe := r.Form["unsafe"]
   217  	v, enableJSON := r.Form["json"]
   218  	var pretty bool
   219  	if len(v) > 0 {
   220  		pretty = v[0] == "pretty"
   221  	}
   222  	return unsafe, &JSONOutput{enable: enableJSON, prettyPrint: pretty}
   223  }
   225  // HTTPReply helper function that takes care of sending the message out
   226  func HTTPReply(w http.ResponseWriter, r *HTTPResult, j *JSONOutput) (int, error) {
   227  	var response []byte
   228  	if j.enable {
   229  		w.Header().Set("Content-Type", "application/json")
   230  		var err error
   231  		if j.prettyPrint {
   232  			response, err = json.MarshalIndent(r, "", "  ")
   233  			if err != nil {
   234  				response, _ = json.MarshalIndent(FailCommand(err), "", "  ")
   235  			}
   236  		} else {
   237  			response, err = json.Marshal(r)
   238  			if err != nil {
   239  				response, _ = json.Marshal(FailCommand(err))
   240  			}
   241  		}
   242  	} else {
   243  		response = []byte(r.String())
   244  	}
   245  	return fmt.Fprint(w, string(response))
   246  }