github.com/oinume/lekcije@v0.0.0-20231017100347-5b4c5eb6ab24/backend/interface/http/middleware.go (about)

     1  package http
     2  
     3  import (
     4  	"fmt"
     5  	"net/http"
     6  	"strings"
     7  	"time"
     8  
     9  	"github.com/google/uuid"
    10  	"github.com/jinzhu/gorm"
    11  	"github.com/morikuni/failure"
    12  	"github.com/rs/cors"
    13  	"go.uber.org/zap"
    14  
    15  	"github.com/oinume/lekcije/backend/context_data"
    16  	"github.com/oinume/lekcije/backend/domain/config"
    17  	"github.com/oinume/lekcije/backend/errors"
    18  	"github.com/oinume/lekcije/backend/model"
    19  	"github.com/oinume/lekcije/backend/usecase"
    20  )
    21  
// Blank reference to fmt.Print so the fmt import stays in use even if every
// other fmt call in this file is removed.
var _ = fmt.Print
    23  
    24  func panicHandler(errorRecorder *usecase.ErrorRecorder) func(http.Handler) http.Handler {
    25  	return func(h http.Handler) http.Handler {
    26  		fn := func(w http.ResponseWriter, r *http.Request) {
    27  			req := r
    28  			defer func() {
    29  				if r := recover(); r != nil {
    30  					var err error
    31  					switch errorType := r.(type) {
    32  					case string:
    33  						err = fmt.Errorf(errorType)
    34  					case error:
    35  						err = errorType
    36  					default:
    37  						err = fmt.Errorf("unknown error type: %v", errorType)
    38  					}
    39  					e := failure.Wrap(err, failure.Message("panic occurred"))
    40  					internalServerError(req.Context(), errorRecorder, w, e, 0)
    41  					return
    42  				}
    43  			}()
    44  			h.ServeHTTP(w, r)
    45  		}
    46  		return http.HandlerFunc(fn)
    47  	}
    48  }
    49  
    50  func accessLogger(logger *zap.Logger) func(http.Handler) http.Handler {
    51  	return func(h http.Handler) http.Handler {
    52  		fn := func(w http.ResponseWriter, r *http.Request) {
    53  			start := time.Now()
    54  			writerProxy := WrapWriter(w)
    55  			h.ServeHTTP(writerProxy, r)
    56  			if r.URL.String() == "/api/webhook/sendGrid" { // Omit access log for papertrail quota
    57  				return
    58  			}
    59  			func() {
    60  				end := time.Now()
    61  				status := writerProxy.Status()
    62  				if status == 0 {
    63  					status = http.StatusOK
    64  				}
    65  				trackingID := ""
    66  				if v, err := context_data.GetTrackingID(r.Context()); err == nil {
    67  					trackingID = v
    68  				}
    69  
    70  				// 180.76.15.26 - - [31/Jul/2016:13:18:07 +0000] "GET / HTTP/1.1" 200 612 "-" "Mozilla/5.0 (compatible; Baiduspider/2.0; +http://www.baidu.com/search/spider.html)"
    71  				logger.Info(
    72  					"access",
    73  					zap.String("method", r.Method),
    74  					zap.String("url", r.URL.String()),
    75  					zap.Int("status", status),
    76  					zap.Int("bytes", writerProxy.BytesWritten()),
    77  					zap.String("remoteAddr", getRemoteAddress(r)),
    78  					zap.String("userAgent", r.Header.Get("User-Agent")),
    79  					zap.String("referer", r.Referer()),
    80  					zap.Duration("elapsed", end.Sub(start)/time.Millisecond),
    81  					zap.String("trackingID", trackingID),
    82  				)
    83  			}()
    84  		}
    85  		return http.HandlerFunc(fn)
    86  	}
    87  }
    88  
    89  func setLoggedInUser(db *gorm.DB) func(http.Handler) http.Handler {
    90  	return func(h http.Handler) http.Handler {
    91  		fn := func(w http.ResponseWriter, r *http.Request) {
    92  			ctx := r.Context()
    93  			if r.RequestURI == "/api/status" {
    94  				h.ServeHTTP(w, r)
    95  				return
    96  			}
    97  			cookie, err := r.Cookie(APITokenCookieName)
    98  			if err != nil {
    99  				h.ServeHTTP(w, r)
   100  				return
   101  			}
   102  
   103  			userService := model.NewUserService(db)
   104  			user, err := userService.FindLoggedInUser(cookie.Value)
   105  			if err != nil {
   106  				h.ServeHTTP(w, r)
   107  				return
   108  			}
   109  			c := context_data.SetLoggedInUser(ctx, user)
   110  			h.ServeHTTP(w, r.WithContext(c))
   111  		}
   112  		return http.HandlerFunc(fn)
   113  	}
   114  }
   115  
   116  func setTrackingID(h http.Handler) http.Handler {
   117  	fn := func(w http.ResponseWriter, r *http.Request) {
   118  		ignoreURLs := []string{
   119  			"/api/status",
   120  			"/robots.txt",
   121  			"/sitemap.xml",
   122  		}
   123  		for _, u := range ignoreURLs {
   124  			if r.RequestURI == u {
   125  				h.ServeHTTP(w, r)
   126  				return
   127  			}
   128  		}
   129  
   130  		cookie, err := r.Cookie(TrackingIDCookieName)
   131  		var trackingID string
   132  		if err == nil {
   133  			trackingID = cookie.Value
   134  		} else {
   135  			trackingID = uuid.New().String()
   136  			domain := strings.Replace(r.Host, "www.", "", 1)
   137  			domain = strings.Replace(domain, ":4000", "", 1) // TODO: local only
   138  			c := &http.Cookie{
   139  				Name:     TrackingIDCookieName,
   140  				Value:    trackingID,
   141  				Path:     "/",
   142  				Domain:   domain,
   143  				Expires:  time.Now().UTC().Add(time.Hour * 24 * 365 * 2),
   144  				HttpOnly: true,
   145  			}
   146  			http.SetCookie(w, c)
   147  		}
   148  		c := context_data.SetTrackingID(r.Context(), trackingID)
   149  		h.ServeHTTP(w, r.WithContext(c))
   150  	}
   151  	return http.HandlerFunc(fn)
   152  }
   153  
   154  func setGAMeasurementEventValues(h http.Handler) http.Handler {
   155  	fn := func(w http.ResponseWriter, r *http.Request) {
   156  		c := context_data.SetGAMeasurementEvent(
   157  			r.Context(),
   158  			newGAMeasurementEventFromRequest(r),
   159  		)
   160  		h.ServeHTTP(w, r.WithContext(c))
   161  	}
   162  	return http.HandlerFunc(fn)
   163  }
   164  
   165  func loginRequiredFilter(db *gorm.DB, appLogger *zap.Logger, errorRecorder *usecase.ErrorRecorder) func(http.Handler) http.Handler {
   166  	return func(h http.Handler) http.Handler {
   167  		fn := func(w http.ResponseWriter, r *http.Request) {
   168  			ctx := r.Context()
   169  			if !strings.HasPrefix(r.RequestURI, "/me") {
   170  				h.ServeHTTP(w, r)
   171  				return
   172  			}
   173  			cookie, err := r.Cookie(APITokenCookieName)
   174  			if err != nil {
   175  				appLogger.Debug("Not logged in")
   176  				http.Redirect(w, r, config.WebURL(), http.StatusFound)
   177  				return
   178  			}
   179  
   180  			// TODO: Use context_data.MustLoggedInUser(ctx)
   181  			userService := model.NewUserService(db)
   182  			user, err := userService.FindLoggedInUser(cookie.Value)
   183  			if err != nil {
   184  				if errors.IsNotFound(err) {
   185  					appLogger.Debug("not logged in")
   186  					http.Redirect(w, r, config.WebURL(), http.StatusFound)
   187  					return
   188  				}
   189  				internalServerError(r.Context(), errorRecorder, w, err, 0)
   190  				return
   191  			}
   192  			appLogger.Debug("Logged in user", zap.String("name", user.Name))
   193  			c := context_data.SetLoggedInUser(ctx, user)
   194  			h.ServeHTTP(w, r.WithContext(c))
   195  		}
   196  		return http.HandlerFunc(fn)
   197  	}
   198  }
   199  
   200  func setCORS(h http.Handler) http.Handler {
   201  	origins := []string{}
   202  	if strings.HasPrefix(config.StaticURL(), "http") {
   203  		origins = append(origins, strings.TrimSuffix(config.StaticURL(), "/static"))
   204  	}
   205  	c := cors.New(cors.Options{
   206  		AllowedOrigins: origins,
   207  		//Debug:          true,
   208  	})
   209  	fn := func(w http.ResponseWriter, r *http.Request) {
   210  		c.HandlerFunc(w, r)
   211  		h.ServeHTTP(w, r)
   212  	}
   213  	return http.HandlerFunc(fn)
   214  }
   215  
   216  func redirecter(h http.Handler) http.Handler {
   217  	fn := func(w http.ResponseWriter, r *http.Request) {
   218  		if r.Host == "lekcije.herokuapp.com" {
   219  			http.Redirect(w, r, config.WebURL()+r.RequestURI, http.StatusMovedPermanently)
   220  			return
   221  		}
   222  		h.ServeHTTP(w, r)
   223  	}
   224  	return http.HandlerFunc(fn)
   225  }
   226  
   227  func setAuthorizationContext(h http.Handler) http.Handler {
   228  	fn := func(w http.ResponseWriter, r *http.Request) {
   229  		auth, err := ParseAuthorizationHeader(r.Header.Get("authorization"))
   230  		if err != nil {
   231  			h.ServeHTTP(w, r)
   232  			return
   233  		}
   234  		r = r.WithContext(context_data.SetAPIToken(r.Context(), strings.TrimSpace(auth)))
   235  		h.ServeHTTP(w, r)
   236  	}
   237  	return http.HandlerFunc(fn)
   238  }
   239  
   240  func ParseAuthorizationHeader(header string) (string, error) {
   241  	// Authorization: Bearer <token>
   242  	auth := strings.Split(header, " ")
   243  	if len(auth) < 2 || strings.ToLower(auth[0]) != "bearer" {
   244  		return "", fmt.Errorf("header value is not valid")
   245  	}
   246  	return auth[1], nil
   247  }