github.com/scottcagno/storage@v1.8.0/pkg/web/servemux.go (about)

     1  package web
     2  
import (
	"fmt"
	"log"
	"mime"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"runtime/debug"
	"sort"
	"strings"
	"sync"
	"time"

	"github.com/scottcagno/storage/pkg/web/logging"
)
    17  
    18  type muxEntry struct {
    19  	method  string
    20  	pattern string
    21  	handler http.Handler
    22  }
    23  
    24  func (m muxEntry) String() string {
    25  	if m.method == http.MethodGet {
    26  		return fmt.Sprintf("[%s]&nbsp;&nbsp;&nbsp;&nbsp;<a href=\"%s\">%s</a>", m.method, m.pattern, m.pattern)
    27  	}
    28  	if m.method == http.MethodPost {
    29  		return fmt.Sprintf("[%s]&nbsp;&nbsp;&nbsp;%s", m.method, m.pattern)
    30  	}
    31  	if m.method == http.MethodPut {
    32  		return fmt.Sprintf("[%s]&nbsp;&nbsp;&nbsp;&nbsp;%s", m.method, m.pattern)
    33  	}
    34  	if m.method == http.MethodDelete {
    35  		return fmt.Sprintf("[%s]&nbsp;%s", m.method, m.pattern)
    36  	}
    37  	return fmt.Sprintf("[%s]&nbsp;%s", m.method, m.pattern)
    38  }
    39  
    40  func (s *ServeMux) Len() int {
    41  	return len(s.es)
    42  }
    43  
    44  func (s *ServeMux) Less(i, j int) bool {
    45  	return s.es[i].pattern < s.es[j].pattern
    46  }
    47  
    48  func (s *ServeMux) Swap(i, j int) {
    49  	s.es[j], s.es[i] = s.es[i], s.es[j]
    50  }
    51  
    52  func (s *ServeMux) Search(x string) int {
    53  	return sort.Search(len(s.es), func(i int) bool {
    54  		return s.es[i].pattern >= x
    55  	})
    56  }
    57  
    58  var (
    59  	defaultStaticPath = "web/static/"
    60  
    61  	DefaultMuxConfigMaxOpts = &MuxConfig{
    62  		StaticPath:   defaultStaticPath,
    63  		WithLogging:  true,
    64  		StdOutLogger: logging.NewStdOutLogger(os.Stdout),
    65  		StdErrLogger: logging.NewStdErrLogger(os.Stderr),
    66  	}
    67  
    68  	defaultMuxConfigMinOpts = &MuxConfig{
    69  		StaticPath: defaultStaticPath,
    70  	}
    71  )
    72  
    73  type MuxConfig struct {
    74  	StaticPath   string
    75  	WithLogging  bool
    76  	StdOutLogger *log.Logger
    77  	StdErrLogger *log.Logger
    78  }
    79  
    80  func checkMuxConfig(conf *MuxConfig) *MuxConfig {
    81  	if conf == nil {
    82  		conf = defaultMuxConfigMinOpts
    83  	}
    84  	if conf.StaticPath == *new(string) {
    85  		conf.StaticPath = defaultStaticPath
    86  	} else {
    87  		conf.StaticPath = filepath.FromSlash(conf.StaticPath + string(filepath.Separator))
    88  	}
    89  	if conf.WithLogging {
    90  		if conf.StdOutLogger == nil {
    91  			conf.StdOutLogger = logging.NewStdOutLogger(os.Stdout)
    92  		}
    93  		if conf.StdErrLogger == nil {
    94  			conf.StdErrLogger = logging.NewStdErrLogger(os.Stderr)
    95  		}
    96  	}
    97  	return conf
    98  }
    99  
   100  type ServeMux struct {
   101  	lock   sync.Mutex
   102  	logger *logging.LevelLogger
   103  	em     map[string]muxEntry
   104  	es     []muxEntry
   105  }
   106  
   107  func NewServeMux(logger *logging.LevelLogger) *ServeMux {
   108  	mux := &ServeMux{
   109  		em: make(map[string]muxEntry),
   110  		es: make([]muxEntry, 0),
   111  	}
   112  	if logger != nil {
   113  		mux.logger = logger
   114  	}
   115  	mux.Get("/favicon.ico", http.NotFoundHandler())
   116  	mux.Get("/info", mux.info())
   117  	return mux
   118  }
   119  
   120  func (s *ServeMux) Handle(method string, pattern string, handler http.Handler) {
   121  	s.lock.Lock()
   122  	defer s.lock.Unlock()
   123  	if pattern == "" {
   124  		panic("http: invalid pattern")
   125  	}
   126  	if handler == nil {
   127  		panic("http: nil handler")
   128  	}
   129  	if _, exist := s.em[pattern]; exist {
   130  		panic("http: multiple registrations for " + pattern)
   131  	}
   132  	entry := muxEntry{
   133  		method:  method,
   134  		pattern: pattern,
   135  		handler: handler,
   136  	}
   137  	s.em[pattern] = entry
   138  	if pattern[len(pattern)-1] == '/' {
   139  		s.es = appendSorted(s.es, entry)
   140  	}
   141  	//s.routes.Put(entry)
   142  }
   143  
   144  func appendSorted(es []muxEntry, e muxEntry) []muxEntry {
   145  	n := len(es)
   146  	i := sort.Search(n, func(i int) bool {
   147  		return len(es[i].pattern) < len(e.pattern)
   148  	})
   149  	if i == n {
   150  		return append(es, e)
   151  	}
   152  	// we now know that i points at where we want to insert
   153  	es = append(es, muxEntry{}) // try to grow the slice in place, any entry works.
   154  	copy(es[i+1:], es[i:])      // Move shorter entries down
   155  	es[i] = e
   156  	return es
   157  }
   158  
   159  func (s *ServeMux) HandleFunc(method, pattern string, handler func(http.ResponseWriter, *http.Request)) {
   160  	if handler == nil {
   161  		panic("http: nil handler")
   162  	}
   163  	s.Handle(method, pattern, http.HandlerFunc(handler))
   164  }
   165  
   166  func (s *ServeMux) Forward(oldpattern string, newpattern string) {
   167  	s.Handle(http.MethodGet, oldpattern, http.RedirectHandler(newpattern, http.StatusTemporaryRedirect))
   168  }
   169  
   170  func (s *ServeMux) Get(pattern string, handler http.Handler) {
   171  	s.Handle(http.MethodGet, pattern, handler)
   172  }
   173  
   174  func (s *ServeMux) Post(pattern string, handler http.Handler) {
   175  	s.Handle(http.MethodPost, pattern, handler)
   176  }
   177  
   178  func (s *ServeMux) Put(pattern string, handler http.Handler) {
   179  	s.Handle(http.MethodPut, pattern, handler)
   180  }
   181  
   182  func (s *ServeMux) Delete(pattern string, handler http.Handler) {
   183  	s.Handle(http.MethodDelete, pattern, handler)
   184  }
   185  
   186  func (s *ServeMux) Static(pattern string, path string) {
   187  	staticHandler := http.StripPrefix(pattern, http.FileServer(http.Dir(path)))
   188  	s.Handle(http.MethodGet, pattern, staticHandler)
   189  }
   190  
   191  func (s *ServeMux) getEntries() []string {
   192  	s.lock.Lock()
   193  	defer s.lock.Unlock()
   194  	var entries []string
   195  	for _, entry := range s.em {
   196  		entries = append(entries, fmt.Sprintf("%s %s\n", entry.method, entry.pattern))
   197  	}
   198  	return entries
   199  }
   200  
   201  func (s *ServeMux) matchEntry(path string) (string, string, http.Handler) {
   202  	e, ok := s.em[path]
   203  	if ok {
   204  		return e.method, e.pattern, e.handler
   205  	}
   206  	for _, e = range s.es {
   207  		if strings.HasPrefix(path, e.pattern) {
   208  			return e.method, e.pattern, e.handler
   209  		}
   210  	}
   211  	return "", "", nil
   212  }
   213  
   214  func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   215  	// check for post request in order to validate via the referer
   216  	if r.Method == "POST" && !strings.Contains(r.Referer(), r.Host) {
   217  		// possibly add origin check in there too...
   218  		//
   219  		code := http.StatusForbidden // invalid request redirect to 403
   220  		http.Error(w, http.StatusText(code), code)
   221  		return
   222  	}
   223  	// look for a matching entry
   224  	m, _, h := s.matchEntry(r.URL.Path)
   225  	if m != r.Method || h == nil {
   226  		// otherwise, return not found
   227  		h = http.NotFoundHandler()
   228  	}
   229  	// if logging is configured, then log, otherwise skip
   230  	if s.logger != nil {
   231  		h = s.requestLogger(h)
   232  	}
   233  	// serve
   234  	h.ServeHTTP(w, r)
   235  }
   236  
   237  func (s *ServeMux) info() http.Handler {
   238  	fn := func(w http.ResponseWriter, r *http.Request) {
   239  		var data []string
   240  		data = append(data, fmt.Sprintf("<h3>Registered Routes (%d)</h3>", len(s.em)))
   241  		for _, entry := range s.em {
   242  			data = append(data, entry.String())
   243  		}
   244  		sort.Slice(data, func(i, j int) bool {
   245  			return data[i] < data[j]
   246  		})
   247  		s.ContentType(w, ".html")
   248  		_, err := fmt.Fprintf(w, strings.Join(data, "<br>"))
   249  		if err != nil {
   250  			code := http.StatusInternalServerError
   251  			http.Error(w, http.StatusText(code), code)
   252  			return
   253  		}
   254  		return
   255  	}
   256  	return http.HandlerFunc(fn)
   257  }
   258  
   259  func (s *ServeMux) ContentType(w http.ResponseWriter, content string) {
   260  	s.lock.Lock()
   261  	defer s.lock.Unlock()
   262  	ct := mime.TypeByExtension(content)
   263  	if ct == "" {
   264  		s.logger.Error("Error, incompatible content type!\n")
   265  		return
   266  	}
   267  	w.Header().Set("Content-Type", ct)
   268  	return
   269  }
   270  
   271  type responseData struct {
   272  	status int
   273  	size   int
   274  }
   275  
   276  type loggingResponseWriter struct {
   277  	http.ResponseWriter
   278  	data *responseData
   279  }
   280  
   281  func (w *loggingResponseWriter) Header() http.Header {
   282  	return w.ResponseWriter.Header()
   283  }
   284  
   285  func (w *loggingResponseWriter) Write(b []byte) (int, error) {
   286  	size, err := w.ResponseWriter.Write(b)
   287  	w.data.size += size
   288  	return size, err
   289  }
   290  
   291  func (w *loggingResponseWriter) WriteHeader(statusCode int) {
   292  	w.ResponseWriter.WriteHeader(statusCode)
   293  	w.data.status = statusCode
   294  }
   295  
   296  func (s *ServeMux) requestLogger(next http.Handler) http.Handler {
   297  	fn := func(w http.ResponseWriter, r *http.Request) {
   298  		defer func() {
   299  			if err := recover(); err != nil {
   300  				w.WriteHeader(http.StatusInternalServerError)
   301  				s.logger.Error("err: %v, trace: %s\n", err, debug.Stack())
   302  			}
   303  		}()
   304  		lrw := loggingResponseWriter{
   305  			ResponseWriter: w,
   306  			data: &responseData{
   307  				status: 200,
   308  				size:   0,
   309  			},
   310  		}
   311  		next.ServeHTTP(&lrw, r)
   312  		if 400 <= lrw.data.status && lrw.data.status <= 599 {
   313  			str, args := logRequest(lrw.data.status, r)
   314  			s.logger.Error(str, args...)
   315  			return
   316  		}
   317  		str, args := logRequest(lrw.data.status, r)
   318  		s.logger.Info(str, args...)
   319  		return
   320  	}
   321  	return http.HandlerFunc(fn)
   322  }
   323  
   324  func logRequest(code int, r *http.Request) (string, []interface{}) {
   325  	format, values := "# %s - - [%s] \"%s %s %s\" %d %d\n", []interface{}{
   326  		r.RemoteAddr,
   327  		time.Now().Format(time.RFC1123Z),
   328  		r.Method,
   329  		r.URL.EscapedPath(),
   330  		r.Proto,
   331  		code,
   332  		r.ContentLength,
   333  	}
   334  	return format, values
   335  }