github.com/scottcagno/storage@v1.8.0/pkg/web/servemux.go (about) 1 package web 2 3 import ( 4 "fmt" 5 "github.com/scottcagno/storage/pkg/web/logging" 6 "log" 7 "mime" 8 "net/http" 9 "os" 10 "path/filepath" 11 "runtime/debug" 12 "sort" 13 "strings" 14 "sync" 15 "time" 16 ) 17 18 type muxEntry struct { 19 method string 20 pattern string 21 handler http.Handler 22 } 23 24 func (m muxEntry) String() string { 25 if m.method == http.MethodGet { 26 return fmt.Sprintf("[%s] <a href=\"%s\">%s</a>", m.method, m.pattern, m.pattern) 27 } 28 if m.method == http.MethodPost { 29 return fmt.Sprintf("[%s] %s", m.method, m.pattern) 30 } 31 if m.method == http.MethodPut { 32 return fmt.Sprintf("[%s] %s", m.method, m.pattern) 33 } 34 if m.method == http.MethodDelete { 35 return fmt.Sprintf("[%s] %s", m.method, m.pattern) 36 } 37 return fmt.Sprintf("[%s] %s", m.method, m.pattern) 38 } 39 40 func (s *ServeMux) Len() int { 41 return len(s.es) 42 } 43 44 func (s *ServeMux) Less(i, j int) bool { 45 return s.es[i].pattern < s.es[j].pattern 46 } 47 48 func (s *ServeMux) Swap(i, j int) { 49 s.es[j], s.es[i] = s.es[i], s.es[j] 50 } 51 52 func (s *ServeMux) Search(x string) int { 53 return sort.Search(len(s.es), func(i int) bool { 54 return s.es[i].pattern >= x 55 }) 56 } 57 58 var ( 59 defaultStaticPath = "web/static/" 60 61 DefaultMuxConfigMaxOpts = &MuxConfig{ 62 StaticPath: defaultStaticPath, 63 WithLogging: true, 64 StdOutLogger: logging.NewStdOutLogger(os.Stdout), 65 StdErrLogger: logging.NewStdErrLogger(os.Stderr), 66 } 67 68 defaultMuxConfigMinOpts = &MuxConfig{ 69 StaticPath: defaultStaticPath, 70 } 71 ) 72 73 type MuxConfig struct { 74 StaticPath string 75 WithLogging bool 76 StdOutLogger *log.Logger 77 StdErrLogger *log.Logger 78 } 79 80 func checkMuxConfig(conf *MuxConfig) *MuxConfig { 81 if conf == nil { 82 conf = defaultMuxConfigMinOpts 83 } 84 if conf.StaticPath == *new(string) { 85 conf.StaticPath = defaultStaticPath 86 } else { 87 conf.StaticPath = filepath.FromSlash(conf.StaticPath + string(filepath.Separator)) 88 } 89 if conf.WithLogging { 90 if conf.StdOutLogger == nil { 91 conf.StdOutLogger = logging.NewStdOutLogger(os.Stdout) 92 } 93 if conf.StdErrLogger == nil { 94 conf.StdErrLogger = logging.NewStdErrLogger(os.Stderr) 95 } 96 } 97 return conf 98 } 99 100 type ServeMux struct { 101 lock sync.Mutex 102 logger *logging.LevelLogger 103 em map[string]muxEntry 104 es []muxEntry 105 } 106 107 func NewServeMux(logger *logging.LevelLogger) *ServeMux { 108 mux := &ServeMux{ 109 em: make(map[string]muxEntry), 110 es: make([]muxEntry, 0), 111 } 112 if logger != nil { 113 mux.logger = logger 114 } 115 mux.Get("/favicon.ico", http.NotFoundHandler()) 116 mux.Get("/info", mux.info()) 117 return mux 118 } 119 120 func (s *ServeMux) Handle(method string, pattern string, handler http.Handler) { 121 s.lock.Lock() 122 defer s.lock.Unlock() 123 if pattern == "" { 124 panic("http: invalid pattern") 125 } 126 if handler == nil { 127 panic("http: nil handler") 128 } 129 if _, exist := s.em[pattern]; exist { 130 panic("http: multiple registrations for " + pattern) 131 } 132 entry := muxEntry{ 133 method: method, 134 pattern: pattern, 135 handler: handler, 136 } 137 s.em[pattern] = entry 138 if pattern[len(pattern)-1] == '/' { 139 s.es = appendSorted(s.es, entry) 140 } 141 //s.routes.Put(entry) 142 } 143 144 func appendSorted(es []muxEntry, e muxEntry) []muxEntry { 145 n := len(es) 146 i := sort.Search(n, func(i int) bool { 147 return len(es[i].pattern) < len(e.pattern) 148 }) 149 if i == n { 150 return append(es, e) 151 } 152 // we now know that i points at where we want to insert 153 es = append(es, muxEntry{}) // try to grow the slice in place, any entry works. 154 copy(es[i+1:], es[i:]) // Move shorter entries down 155 es[i] = e 156 return es 157 } 158 159 func (s *ServeMux) HandleFunc(method, pattern string, handler func(http.ResponseWriter, *http.Request)) { 160 if handler == nil { 161 panic("http: nil handler") 162 } 163 s.Handle(method, pattern, http.HandlerFunc(handler)) 164 } 165 166 func (s *ServeMux) Forward(oldpattern string, newpattern string) { 167 s.Handle(http.MethodGet, oldpattern, http.RedirectHandler(newpattern, http.StatusTemporaryRedirect)) 168 } 169 170 func (s *ServeMux) Get(pattern string, handler http.Handler) { 171 s.Handle(http.MethodGet, pattern, handler) 172 } 173 174 func (s *ServeMux) Post(pattern string, handler http.Handler) { 175 s.Handle(http.MethodPost, pattern, handler) 176 } 177 178 func (s *ServeMux) Put(pattern string, handler http.Handler) { 179 s.Handle(http.MethodPut, pattern, handler) 180 } 181 182 func (s *ServeMux) Delete(pattern string, handler http.Handler) { 183 s.Handle(http.MethodDelete, pattern, handler) 184 } 185 186 func (s *ServeMux) Static(pattern string, path string) { 187 staticHandler := http.StripPrefix(pattern, http.FileServer(http.Dir(path))) 188 s.Handle(http.MethodGet, pattern, staticHandler) 189 } 190 191 func (s *ServeMux) getEntries() []string { 192 s.lock.Lock() 193 defer s.lock.Unlock() 194 var entries []string 195 for _, entry := range s.em { 196 entries = append(entries, fmt.Sprintf("%s %s\n", entry.method, entry.pattern)) 197 } 198 return entries 199 } 200 201 func (s *ServeMux) matchEntry(path string) (string, string, http.Handler) { 202 e, ok := s.em[path] 203 if ok { 204 return e.method, e.pattern, e.handler 205 } 206 for _, e = range s.es { 207 if strings.HasPrefix(path, e.pattern) { 208 return e.method, e.pattern, e.handler 209 } 210 } 211 return "", "", nil 212 } 213 214 func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { 215 // check for post request in order to validate via the referer 216 if r.Method == "POST" && !strings.Contains(r.Referer(), r.Host) { 217 // possibly add origin check in there too... 218 // 219 code := http.StatusForbidden // invalid request redirect to 403 220 http.Error(w, http.StatusText(code), code) 221 return 222 } 223 // look for a matching entry 224 m, _, h := s.matchEntry(r.URL.Path) 225 if m != r.Method || h == nil { 226 // otherwise, return not found 227 h = http.NotFoundHandler() 228 } 229 // if logging is configured, then log, otherwise skip 230 if s.logger != nil { 231 h = s.requestLogger(h) 232 } 233 // serve 234 h.ServeHTTP(w, r) 235 } 236 237 func (s *ServeMux) info() http.Handler { 238 fn := func(w http.ResponseWriter, r *http.Request) { 239 var data []string 240 data = append(data, fmt.Sprintf("<h3>Registered Routes (%d)</h3>", len(s.em))) 241 for _, entry := range s.em { 242 data = append(data, entry.String()) 243 } 244 sort.Slice(data, func(i, j int) bool { 245 return data[i] < data[j] 246 }) 247 s.ContentType(w, ".html") 248 _, err := fmt.Fprintf(w, strings.Join(data, "<br>")) 249 if err != nil { 250 code := http.StatusInternalServerError 251 http.Error(w, http.StatusText(code), code) 252 return 253 } 254 return 255 } 256 return http.HandlerFunc(fn) 257 } 258 259 func (s *ServeMux) ContentType(w http.ResponseWriter, content string) { 260 s.lock.Lock() 261 defer s.lock.Unlock() 262 ct := mime.TypeByExtension(content) 263 if ct == "" { 264 s.logger.Error("Error, incompatible content type!\n") 265 return 266 } 267 w.Header().Set("Content-Type", ct) 268 return 269 } 270 271 type responseData struct { 272 status int 273 size int 274 } 275 276 type loggingResponseWriter struct { 277 http.ResponseWriter 278 data *responseData 279 } 280 281 func (w *loggingResponseWriter) Header() http.Header { 282 return w.ResponseWriter.Header() 283 } 284 285 func (w *loggingResponseWriter) Write(b []byte) (int, error) { 286 size, err := w.ResponseWriter.Write(b) 287 w.data.size += size 288 return size, err 289 } 290 291 func (w *loggingResponseWriter) WriteHeader(statusCode int) { 292 w.ResponseWriter.WriteHeader(statusCode) 293 w.data.status = statusCode 294 } 295 296 func (s *ServeMux) requestLogger(next http.Handler) http.Handler { 297 fn := func(w http.ResponseWriter, r *http.Request) { 298 defer func() { 299 if err := recover(); err != nil { 300 w.WriteHeader(http.StatusInternalServerError) 301 s.logger.Error("err: %v, trace: %s\n", err, debug.Stack()) 302 } 303 }() 304 lrw := loggingResponseWriter{ 305 ResponseWriter: w, 306 data: &responseData{ 307 status: 200, 308 size: 0, 309 }, 310 } 311 next.ServeHTTP(&lrw, r) 312 if 400 <= lrw.data.status && lrw.data.status <= 599 { 313 str, args := logRequest(lrw.data.status, r) 314 s.logger.Error(str, args...) 315 return 316 } 317 str, args := logRequest(lrw.data.status, r) 318 s.logger.Info(str, args...) 319 return 320 } 321 return http.HandlerFunc(fn) 322 } 323 324 func logRequest(code int, r *http.Request) (string, []interface{}) { 325 format, values := "# %s - - [%s] \"%s %s %s\" %d %d\n", []interface{}{ 326 r.RemoteAddr, 327 time.Now().Format(time.RFC1123Z), 328 r.Method, 329 r.URL.EscapedPath(), 330 r.Proto, 331 code, 332 r.ContentLength, 333 } 334 return format, values 335 }