gitlab.com/ignitionrobotics/web/ign-go@v1.0.0-rc4/router_middleware.go (about)

     1  package ign
     2  
     3  import (
     4  	"encoding/json"
     5  	"fmt"
     6  	"github.com/auth0/go-jwt-middleware"
     7  	"github.com/codegangsta/negroni"
     8  	"github.com/dgrijalva/jwt-go"
     9  	"github.com/golang/protobuf/proto"
    10  	"github.com/jinzhu/gorm"
    11  	"github.com/jpillora/go-ogle-analytics"
    12  	"github.com/mssola/user_agent"
    13  	"github.com/satori/go.uuid"
    14  	"log"
    15  	"net/http"
    16  	"reflect"
    17  	"strings"
    18  	"time"
    19  )
    20  
    21  // Handler represents an HTTP Handler that can also return a ErrMsg
    22  // See https://blog.golang.org/error-handling-and-go
    23  type Handler func(*gorm.DB, http.ResponseWriter, *http.Request) *ErrMsg
    24  
    25  // HandlerWithResult represents an HTTP Handler that that has a result
    26  type HandlerWithResult func(tx *gorm.DB, w http.ResponseWriter,
    27  	r *http.Request) (interface{}, *ErrMsg)
    28  
    29  // TypeJSONResult represents a function result that can be exported to JSON
    30  type TypeJSONResult struct {
    31  	wrapperField string
    32  	fn           HandlerWithResult
    33  	wrapWithTx   bool
    34  }
    35  
    36  // ProtoResult provides protobuf serialization for handler results
    37  type ProtoResult HandlerWithResult
    38  
    39  // JSONResult provides JSON serialization for handler results
    40  func JSONResult(handler HandlerWithResult) TypeJSONResult {
    41  	return TypeJSONResult{"", handler, true}
    42  }
    43  
    44  // IsBotHandler decides which handler to use whether the request was made by a
    45  // bot or a user.
    46  func IsBotHandler(botHandler http.Handler, userHandler http.Handler) http.Handler {
    47  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    48  		var handler http.Handler
    49  		ua := user_agent.New(r.Header.Get("User-Agent"))
    50  		if (ua.Bot()) {
    51  			handler = botHandler
    52  		} else {
    53  			handler = userHandler
    54  		}
    55  		handler.ServeHTTP(w, r)
    56  	})
    57  }
    58  
    59  // JSONResultNoTx provides JSON serialization for handler results
    60  func JSONResultNoTx(handler HandlerWithResult) TypeJSONResult {
    61  	return TypeJSONResult{"", handler, false}
    62  }
    63  
    64  // JSONListResult provides JSON serialization for handler results that are
    65  // slices of objects.
    66  func JSONListResult(wrapper string, handler HandlerWithResult) TypeJSONResult {
    67  	return TypeJSONResult{wrapper, handler, true}
    68  }
    69  
    70  /////////////////////////////////////////////////
    71  
    72  func (fn Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    73  	txFunc := dbTransactionWrapper(handlerToHandlerWithResult(fn))
    74  	if _, err := txFunc(w, r); err != nil {
    75  		reportJSONError(w, r, *err)
    76  	}
    77  }
    78  
    79  /////////////////////////////////////////////////
    80  
    81  // basicHandlerWith represents a basic handler function that returns a result and an error.
    82  type basicHandlerWithResult func(w http.ResponseWriter, r *http.Request) (interface{}, *ErrMsg)
    83  
    84  // IsSQLTxError checks if the given error is a sqlTx error.
    85  // Note: we need to do that by testing its error message.
    86  func IsSQLTxError(err error) bool {
    87  	return err != nil && strings.ToLower(err.Error()) == "sql: transaction has already been committed or rolled back"
    88  }
    89  
    90  // dbTransactionWrapper handles opening and closing of a DB Transaction.
    91  // It invokes the given handler with the created TX.
    92  // By using this wrapper , real handlers won't need to open and close the TX.
    93  // IMPORTANT NOTE: note that once you write data (not headers) into the
    94  // ResponseWriter then the status code is set to 200 (OK). Keep that in mind
    95  // when coding your Handler logic (eg. when using fmt.Fprint(w, ...))
    96  func dbTransactionWrapper(handler HandlerWithResult) basicHandlerWithResult {
    97  	return func(w http.ResponseWriter, r *http.Request) (interface{}, *ErrMsg) {
    98  		tx := gServer.Db.Begin()
    99  		if tx.Error != nil {
   100  			return nil, NewErrorMessageWithBase(ErrorNoDatabase, tx.Error)
   101  		}
   102  
   103  		defer func() {
   104  			// check for panic (to close sql connections)
   105  			if p := recover(); p != nil {
   106  				tx.Rollback()
   107  				panic(p) // re-throw panic after Rollback
   108  			}
   109  		}()
   110  		result, em := handler(tx, w, r)
   111  		if em != nil {
   112  			tx.Rollback()
   113  		} else {
   114  			// Commit DB transaction
   115  			err := tx.Commit().Error
   116  			if err != nil && !IsSQLTxError(err) {
   117  				// re-throw error if different than TX already committed/rollbacked err
   118  				result, em = nil, NewErrorMessageWithBase(ErrorNoDatabase, err)
   119  			}
   120  		}
   121  		return result, em
   122  	}
   123  }
   124  
   125  // handlerToHandlerWithResult converts an ign.Handler to an
   126  // ign.HandlerWithResult.
   127  func handlerToHandlerWithResult(handler Handler) HandlerWithResult {
   128  	return func(tx *gorm.DB, w http.ResponseWriter, r *http.Request) (interface{}, *ErrMsg) {
   129  		err := handler(tx, w, r)
   130  		return nil, err
   131  	}
   132  }
   133  
   134  func (t TypeJSONResult) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   135  	var txFunc basicHandlerWithResult
   136  	if t.wrapWithTx {
   137  		txFunc = dbTransactionWrapper(t.fn)
   138  	} else {
   139  		txFunc = func(w http.ResponseWriter, r *http.Request) (interface{}, *ErrMsg) {
   140  			return t.fn(gServer.Db, w, r)
   141  		}
   142  	}
   143  	result, err := txFunc(w, r)
   144  	if err != nil {
   145  		reportJSONError(w, r, *err)
   146  		return
   147  	}
   148  
   149  	var data interface{}
   150  	// Is there any wrapper field to cut off ?
   151  	if t.wrapperField != "" {
   152  		value := reflect.ValueOf(result)
   153  		fieldValue := reflect.Indirect(value).FieldByName(t.wrapperField)
   154  		data = fieldValue.Interface()
   155  		// If the underlying data is an empty slice then force the creation of
   156  		// an empty json `[]` as output
   157  		if fieldValue.Kind() == reflect.Slice && fieldValue.Len() == 0 {
   158  			data = make([]string, 0)
   159  		}
   160  	} else {
   161  		data = result
   162  	}
   163  	w.Header().Set("Content-Type", "application/json")
   164  	// Marshal the response into a JSON
   165  	if err := json.NewEncoder(w).Encode(data); err != nil {
   166  		em := NewErrorMessageWithBase(ErrorMarshalJSON, err)
   167  		reportJSONError(w, r, *em)
   168  		return
   169  	}
   170  }
   171  
   172  /////////////////////////////////////////////////
   173  
   174  func (fn ProtoResult) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   175  	txFunc := dbTransactionWrapper(HandlerWithResult(fn))
   176  	result, err := txFunc(w, r)
   177  	if err != nil {
   178  		reportJSONError(w, r, *err)
   179  		return
   180  	}
   181  
   182  	// Marshal the protobuf data and write it out.
   183  	var pm = result.(proto.Message)
   184  	data, e := proto.Marshal(pm)
   185  	if e != nil {
   186  		em := NewErrorMessageWithBase(ErrorMarshalProto, e)
   187  		reportJSONError(w, r, *em)
   188  		return
   189  	}
   190  	w.Header().Set("Content-Type", "application/arraybuffer")
   191  	w.Write(data)
   192  }
   193  
   194  /////////////////////////////////////////////////
   195  
   196  // ReportJSONError logs an error message and return an HTTP error including
   197  // JSON payload
   198  func reportJSONError(w http.ResponseWriter, r *http.Request, errMsg ErrMsg) {
   199  	errMsg.UserAgent = r.UserAgent()
   200  	errMsg.RemoteAddress = getIPAddress(r)
   201  	if errMsg.Route == "" {
   202  		errMsg.Route = r.Method + " " + r.RequestURI
   203  	}
   204  	// Report the error to rollbar, and output to console
   205  	LoggerFromRequest(r).Error(errMsg, r)
   206  
   207  	output, err := json.Marshal(errMsg)
   208  	if err != nil {
   209  		reportError(w, "Unable to marshal JSON", http.StatusServiceUnavailable)
   210  		return
   211  	}
   212  
   213  	http.Error(w, string(output), errMsg.StatusCode)
   214  }
   215  
   216  // reportError logs an error message and return an HTTP error
   217  func reportError(w http.ResponseWriter, msg string, errCode int) {
   218  	log.Println("Error in [" + Trace(3) + "]\n\t" + msg)
   219  	http.Error(w, msg, errCode)
   220  }
   221  
   222  /////////////////////////////////////////////////
   223  
   224  // JWTMiddlewareIgn wraps jwtmiddleware.JWTMiddleware so that we can create
   225  // a custom AccessTokenHandler that first checks for a Private-Token and then
   226  // checks for a JWT if the Private-Token doesn't exist.
   227  type JWTMiddlewareIgn struct {
   228  	*jwtmiddleware.JWTMiddleware
   229  }
   230  
   231  // AccessTokenHandler first checks for a Private-Token and then
   232  // checks for a JWT if the Private-Token doesn't exist.
   233  func (m *JWTMiddlewareIgn) AccessTokenHandler(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
   234  	// Check if a Private-Token is used, which will supercede a JWT token.
   235  	if token := r.Header.Get("Private-Token"); len(token) > 0 {
   236  
   237  		var errorMsg *ErrMsg
   238  
   239  		tx := gServer.UsersDb.Begin()
   240  		defer func() {
   241  			// check for panic (to close sql connections)
   242  			if p := recover(); p != nil {
   243  				tx.Rollback()
   244  				panic(p) // re-throw panic after Rollback
   245  			}
   246  		}()
   247  
   248  		var accessToken *AccessToken
   249  		if tx.Error != nil {
   250  			errorMsg = NewErrorMessageWithBase(ErrorNoDatabase, tx.Error)
   251  		} else {
   252  			accessToken, errorMsg = ValidateAccessToken(token, tx)
   253  		}
   254  
   255  		if errorMsg != nil {
   256  			logger := NewLoggerWithRollbarVerbosity("AccessTokenHandler", gServer.LogToStd, gServer.LogVerbosity, gServer.RollbarLogVerbosity)
   257  			logger.Error(errorMsg)
   258  			m.Options.ErrorHandler(w, r, errorMsg.Msg)
   259  			tx.Rollback()
   260  			return
   261  		}
   262  
   263  		if accessToken.LastUsed == nil {
   264  			accessToken.LastUsed = new(time.Time)
   265  		}
   266  
   267  		*(accessToken.LastUsed) = time.Now()
   268  		tx.Save(accessToken)
   269  		tx.Commit()
   270  
   271  		next(w, r)
   272  	} else {
   273  		m.HandlerWithNext(w, r, next)
   274  	}
   275  }
   276  
   277  // CreateJWTOptionalMiddleware creates and returns a middleware that
   278  // allows requests with optional JWT tokens.
   279  func CreateJWTOptionalMiddleware(s *Server) negroni.HandlerFunc {
   280  	// See https://github.com/auth0/go-jwt-middleware
   281  	opt := jwtmiddleware.New(
   282  		jwtmiddleware.Options{
   283  			Debug:               false,
   284  			CredentialsOptional: true,
   285  			SigningMethod:       jwt.SigningMethodRS256,
   286  			ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) {
   287  				return jwt.ParseRSAPublicKeyFromPEM([]byte(s.pemKeyString))
   288  			},
   289  		})
   290  	return negroni.HandlerFunc(opt.HandlerWithNext)
   291  }
   292  
   293  // CreateJWTRequiredMiddleware creates and returns a middleware that
   294  // rejects requests that do not have a JWT token.
   295  func CreateJWTRequiredMiddleware(s *Server) negroni.HandlerFunc {
   296  	req := &JWTMiddlewareIgn{jwtmiddleware.New(jwtmiddleware.Options{
   297  		Debug:               false,
   298  		SigningMethod:       jwt.SigningMethodRS256,
   299  		CredentialsOptional: false,
   300  		ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) {
   301  			return jwt.ParseRSAPublicKeyFromPEM([]byte(s.pemKeyString))
   302  		},
   303  	})}
   304  
   305  	return negroni.HandlerFunc(req.AccessTokenHandler)
   306  }
   307  
   308  // Middleware to ensure the DB instance exists.
   309  // By having this middleware, then any route handler can safely assume the DB
   310  // is present.
   311  func requireDBMiddleware(w http.ResponseWriter, r *http.Request,
   312  	next http.HandlerFunc) {
   313  
   314  	if gServer.Db == nil {
   315  		errMsg := ErrorMessage(ErrorNoDatabase)
   316  		reportJSONError(w, r, errMsg)
   317  	} else {
   318  		next(w, r)
   319  	}
   320  }
   321  
   322  // addCORSheadersMiddleware adds CORS related headers to an http response.
   323  func addCORSheadersMiddleware(w http.ResponseWriter, r *http.Request,
   324  	next http.HandlerFunc) {
   325  	addCORSheaders(w)
   326  	next(w, r)
   327  }
   328  
   329  // addCORSheaders adds the required Access Control headers to the HTTP response
   330  func addCORSheaders(w http.ResponseWriter) {
   331  	w.Header().Set("Access-Control-Allow-Methods",
   332  		"GET, HEAD, POST, PUT, PATCH, DELETE")
   333  
   334  	w.Header().Set("Access-Control-Allow-Credentials", "true")
   335  
   336  	w.Header().Set("Access-Control-Allow-Headers",
   337  		`Accept, Accept-Language, Content-Language, Origin,
   338                    Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token,
   339                    Authorization`)
   340  	w.Header().Set("Access-Control-Allow-Origin", "*")
   341  
   342  	w.Header().Set("Access-Control-Expose-Headers", "Link, X-Total-Count, X-Ign-Resource-Version")
   343  }
   344  
   345  // getRequestID gets the request's X-Request-ID header OR, if the header is empty,
   346  // returns a generated UUID string.
   347  func getRequestID(r *http.Request) string {
   348  	reqID := r.Header.Get("X-Request-ID")
   349  	if reqID == "" {
   350  		reqID = uuid.NewV4().String()
   351  	}
   352  	return reqID
   353  }
   354  
   355  /////////////////////////////////////////////////
   356  // logger creates a middleware used to output HTTP requests.
   357  func logger(inner http.Handler, name string) http.Handler {
   358  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   359  		start := time.Now()
   360  		reqID := getRequestID(r)
   361  		logger := NewLoggerWithRollbarVerbosity(reqID, gServer.LogToStd, gServer.LogVerbosity, gServer.RollbarLogVerbosity)
   362  		logCtx := NewContextWithLogger(r.Context(), logger)
   363  
   364  		logger.Info(fmt.Sprintf("Incoming req: %s %s %s",
   365  			r.Method,
   366  			r.RequestURI,
   367  			name,
   368  		))
   369  		// run the server logic
   370  		inner.ServeHTTP(w, r.WithContext(logCtx))
   371  		// log output
   372  		logger.Info(fmt.Sprintf("Finished req: %s %s %s %s",
   373  			r.Method,
   374  			r.RequestURI,
   375  			name,
   376  			time.Since(start),
   377  		))
   378  	})
   379  }
   380  
   381  /////////////////////////////////////////////////
   382  // newGaEventTracking creates a new middleware to send events to Google Analytics.
   383  // Events will be automatically created using route information.
   384  // This middleware requires IGN_GA_TRACKING_ID and IGN_GA_APP_NAME
   385  // env vars.
   386  func newGaEventTracking(routeName string) negroni.HandlerFunc {
   387  	return func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
   388  		next(w, r)
   389  
   390  		// Track event with GA, if enabled
   391  		if gServer.GaAppName == "" || gServer.GaTrackingID == "" {
   392  			return
   393  		}
   394  		c, err := ga.NewClient(gServer.GaTrackingID)
   395  		if err != nil {
   396  			LoggerFromRequest(r).Error("Error creating GA client", err, r)
   397  			return
   398  		}
   399  		c.DataSource(gServer.GaAppName)
   400  		c.ApplicationName(gServer.GaAppName)
   401  		cat := gServer.GaCategoryPrefix + routeName
   402  		action := r.Method
   403  		e := ga.NewEvent(cat, action).Label(r.URL.String())
   404  		if err := c.Send(e); err != nil {
   405  			fmt.Println("Error while sending event to GA", err)
   406  		}
   407  	}
   408  }