github.com/oinume/lekcije@v0.0.0-20231017100347-5b4c5eb6ab24/backend/interface/http/middleware.go (about) 1 package http 2 3 import ( 4 "fmt" 5 "net/http" 6 "strings" 7 "time" 8 9 "github.com/google/uuid" 10 "github.com/jinzhu/gorm" 11 "github.com/morikuni/failure" 12 "github.com/rs/cors" 13 "go.uber.org/zap" 14 15 "github.com/oinume/lekcije/backend/context_data" 16 "github.com/oinume/lekcije/backend/domain/config" 17 "github.com/oinume/lekcije/backend/errors" 18 "github.com/oinume/lekcije/backend/model" 19 "github.com/oinume/lekcije/backend/usecase" 20 ) 21 22 var _ = fmt.Print 23 24 func panicHandler(errorRecorder *usecase.ErrorRecorder) func(http.Handler) http.Handler { 25 return func(h http.Handler) http.Handler { 26 fn := func(w http.ResponseWriter, r *http.Request) { 27 req := r 28 defer func() { 29 if r := recover(); r != nil { 30 var err error 31 switch errorType := r.(type) { 32 case string: 33 err = fmt.Errorf(errorType) 34 case error: 35 err = errorType 36 default: 37 err = fmt.Errorf("unknown error type: %v", errorType) 38 } 39 e := failure.Wrap(err, failure.Message("panic occurred")) 40 internalServerError(req.Context(), errorRecorder, w, e, 0) 41 return 42 } 43 }() 44 h.ServeHTTP(w, r) 45 } 46 return http.HandlerFunc(fn) 47 } 48 } 49 50 func accessLogger(logger *zap.Logger) func(http.Handler) http.Handler { 51 return func(h http.Handler) http.Handler { 52 fn := func(w http.ResponseWriter, r *http.Request) { 53 start := time.Now() 54 writerProxy := WrapWriter(w) 55 h.ServeHTTP(writerProxy, r) 56 if r.URL.String() == "/api/webhook/sendGrid" { // Omit access log for papertrail quota 57 return 58 } 59 func() { 60 end := time.Now() 61 status := writerProxy.Status() 62 if status == 0 { 63 status = http.StatusOK 64 } 65 trackingID := "" 66 if v, err := context_data.GetTrackingID(r.Context()); err == nil { 67 trackingID = v 68 } 69 70 // 180.76.15.26 - - [31/Jul/2016:13:18:07 +0000] "GET / HTTP/1.1" 200 612 "-" "Mozilla/5.0 (compatible; Baiduspider/2.0; +http://www.baidu.com/search/spider.html)" 71 logger.Info( 72 "access", 73 zap.String("method", r.Method), 74 zap.String("url", r.URL.String()), 75 zap.Int("status", status), 76 zap.Int("bytes", writerProxy.BytesWritten()), 77 zap.String("remoteAddr", getRemoteAddress(r)), 78 zap.String("userAgent", r.Header.Get("User-Agent")), 79 zap.String("referer", r.Referer()), 80 zap.Duration("elapsed", end.Sub(start)/time.Millisecond), 81 zap.String("trackingID", trackingID), 82 ) 83 }() 84 } 85 return http.HandlerFunc(fn) 86 } 87 } 88 89 func setLoggedInUser(db *gorm.DB) func(http.Handler) http.Handler { 90 return func(h http.Handler) http.Handler { 91 fn := func(w http.ResponseWriter, r *http.Request) { 92 ctx := r.Context() 93 if r.RequestURI == "/api/status" { 94 h.ServeHTTP(w, r) 95 return 96 } 97 cookie, err := r.Cookie(APITokenCookieName) 98 if err != nil { 99 h.ServeHTTP(w, r) 100 return 101 } 102 103 userService := model.NewUserService(db) 104 user, err := userService.FindLoggedInUser(cookie.Value) 105 if err != nil { 106 h.ServeHTTP(w, r) 107 return 108 } 109 c := context_data.SetLoggedInUser(ctx, user) 110 h.ServeHTTP(w, r.WithContext(c)) 111 } 112 return http.HandlerFunc(fn) 113 } 114 } 115 116 func setTrackingID(h http.Handler) http.Handler { 117 fn := func(w http.ResponseWriter, r *http.Request) { 118 ignoreURLs := []string{ 119 "/api/status", 120 "/robots.txt", 121 "/sitemap.xml", 122 } 123 for _, u := range ignoreURLs { 124 if r.RequestURI == u { 125 h.ServeHTTP(w, r) 126 return 127 } 128 } 129 130 cookie, err := r.Cookie(TrackingIDCookieName) 131 var trackingID string 132 if err == nil { 133 trackingID = cookie.Value 134 } else { 135 trackingID = uuid.New().String() 136 domain := strings.Replace(r.Host, "www.", "", 1) 137 domain = strings.Replace(domain, ":4000", "", 1) // TODO: local only 138 c := &http.Cookie{ 139 Name: TrackingIDCookieName, 140 Value: trackingID, 141 Path: "/", 142 Domain: domain, 143 Expires: time.Now().UTC().Add(time.Hour * 24 * 365 * 2), 144 HttpOnly: true, 145 } 146 http.SetCookie(w, c) 147 } 148 c := context_data.SetTrackingID(r.Context(), trackingID) 149 h.ServeHTTP(w, r.WithContext(c)) 150 } 151 return http.HandlerFunc(fn) 152 } 153 154 func setGAMeasurementEventValues(h http.Handler) http.Handler { 155 fn := func(w http.ResponseWriter, r *http.Request) { 156 c := context_data.SetGAMeasurementEvent( 157 r.Context(), 158 newGAMeasurementEventFromRequest(r), 159 ) 160 h.ServeHTTP(w, r.WithContext(c)) 161 } 162 return http.HandlerFunc(fn) 163 } 164 165 func loginRequiredFilter(db *gorm.DB, appLogger *zap.Logger, errorRecorder *usecase.ErrorRecorder) func(http.Handler) http.Handler { 166 return func(h http.Handler) http.Handler { 167 fn := func(w http.ResponseWriter, r *http.Request) { 168 ctx := r.Context() 169 if !strings.HasPrefix(r.RequestURI, "/me") { 170 h.ServeHTTP(w, r) 171 return 172 } 173 cookie, err := r.Cookie(APITokenCookieName) 174 if err != nil { 175 appLogger.Debug("Not logged in") 176 http.Redirect(w, r, config.WebURL(), http.StatusFound) 177 return 178 } 179 180 // TODO: Use context_data.MustLoggedInUser(ctx) 181 userService := model.NewUserService(db) 182 user, err := userService.FindLoggedInUser(cookie.Value) 183 if err != nil { 184 if errors.IsNotFound(err) { 185 appLogger.Debug("not logged in") 186 http.Redirect(w, r, config.WebURL(), http.StatusFound) 187 return 188 } 189 internalServerError(r.Context(), errorRecorder, w, err, 0) 190 return 191 } 192 appLogger.Debug("Logged in user", zap.String("name", user.Name)) 193 c := context_data.SetLoggedInUser(ctx, user) 194 h.ServeHTTP(w, r.WithContext(c)) 195 } 196 return http.HandlerFunc(fn) 197 } 198 } 199 200 func setCORS(h http.Handler) http.Handler { 201 origins := []string{} 202 if strings.HasPrefix(config.StaticURL(), "http") { 203 origins = append(origins, strings.TrimSuffix(config.StaticURL(), "/static")) 204 } 205 c := cors.New(cors.Options{ 206 AllowedOrigins: origins, 207 //Debug: true, 208 }) 209 fn := func(w http.ResponseWriter, r *http.Request) { 210 c.HandlerFunc(w, r) 211 h.ServeHTTP(w, r) 212 } 213 return http.HandlerFunc(fn) 214 } 215 216 func redirecter(h http.Handler) http.Handler { 217 fn := func(w http.ResponseWriter, r *http.Request) { 218 if r.Host == "lekcije.herokuapp.com" { 219 http.Redirect(w, r, config.WebURL()+r.RequestURI, http.StatusMovedPermanently) 220 return 221 } 222 h.ServeHTTP(w, r) 223 } 224 return http.HandlerFunc(fn) 225 } 226 227 func setAuthorizationContext(h http.Handler) http.Handler { 228 fn := func(w http.ResponseWriter, r *http.Request) { 229 auth, err := ParseAuthorizationHeader(r.Header.Get("authorization")) 230 if err != nil { 231 h.ServeHTTP(w, r) 232 return 233 } 234 r = r.WithContext(context_data.SetAPIToken(r.Context(), strings.TrimSpace(auth))) 235 h.ServeHTTP(w, r) 236 } 237 return http.HandlerFunc(fn) 238 } 239 240 func ParseAuthorizationHeader(header string) (string, error) { 241 // Authorization: Bearer <token> 242 auth := strings.Split(header, " ") 243 if len(auth) < 2 || strings.ToLower(auth[0]) != "bearer" { 244 return "", fmt.Errorf("header value is not valid") 245 } 246 return auth[1], nil 247 }