decred.org/dcrdex@v1.0.5/client/webserver/middleware.go (about)

     1  // This code is available on the terms of the project LICENSE.md file,
     2  // also available online at https://blueoakcouncil.org/license/1.0.0.
     3  
     4  package webserver
     5  
     6  import (
     7  	"context"
     8  	"encoding/hex"
     9  	"errors"
    10  	"fmt"
    11  	"net/http"
    12  
    13  	"decred.org/dcrdex/dex"
    14  	"decred.org/dcrdex/dex/order"
    15  	"github.com/go-chi/chi/v5"
    16  )
    17  
    18  type ctxID int
    19  
    20  const (
    21  	ctxOID ctxID = iota
    22  	ctxHost
    23  )
    24  
    25  // securityMiddleware adds security headers to the server responses.
    26  func (s *WebServer) securityMiddleware(next http.Handler) http.Handler {
    27  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    28  		w.Header().Set("x-frame-options", "DENY")
    29  		w.Header().Set("X-XSS-Protection", "1; mode=block")
    30  		w.Header().Set("X-Content-Type-Options", "nosniff")
    31  		w.Header().Set("Referrer-Policy", "no-referrer")
    32  		w.Header().Set("Content-Security-Policy", s.csp)
    33  		w.Header().Set("Permissions-Policy", "geolocation=(), midi=(), sync-xhr=(self), microphone=(), camera=(), magnetometer=(), gyroscope=(), fullscreen=(self), payment=()")
    34  		next.ServeHTTP(w, r)
    35  	})
    36  }
    37  
    38  // authMiddleware checks incoming requests for cookie-based information
    39  // including the auth token. Use extractUserInfo to access the *userInfo in
    40  // downstream handlers. This should be used with care since it involves a call
    41  // to (*Core).User, which can be expensive.
    42  func (s *WebServer) authMiddleware(next http.Handler) http.Handler {
    43  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    44  		ctx := context.WithValue(r.Context(), ctxKeyUserInfo, &userInfo{
    45  			Authed:           s.isAuthed(r),
    46  			PasswordIsCached: s.isPasswordCached(r),
    47  		})
    48  		next.ServeHTTP(w, r.WithContext(ctx))
    49  	})
    50  }
    51  
    52  // extractBooleanCookie extracts the cookie value with key k from the Request,
    53  // and interprets the value as true only if it's equal to the string "1".
    54  func extractBooleanCookie(r *http.Request, k string, defaultVal bool) bool {
    55  	cookie, err := r.Cookie(k)
    56  	switch {
    57  	// Dark mode is the default
    58  	case err == nil:
    59  		return cookie.Value == "1"
    60  	case errors.Is(err, http.ErrNoCookie):
    61  	default:
    62  		log.Errorf("Cookie %q retrieval error: %v", k, err)
    63  	}
    64  	return defaultVal
    65  }
    66  
    67  // requireInit ensures that the core app is initialized before allowing the
    68  // incoming request to proceed. Redirects to the register page if the app is
    69  // not initialized.
    70  func (s *WebServer) requireInit(next http.Handler) http.Handler {
    71  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    72  		if !s.core.IsInitialized() {
    73  			http.Redirect(w, r, initRoute, http.StatusSeeOther)
    74  			return
    75  		}
    76  		next.ServeHTTP(w, r)
    77  	})
    78  }
    79  
    80  // requireNotInit ensures that the core app is not initialized before allowing
    81  // the incoming request to proceed. Redirects to the login page if the app is
    82  // initialized and the user is not logged in. If logged in, directs to the
    83  // wallets page.
    84  func (s *WebServer) requireNotInit(next http.Handler) http.Handler {
    85  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    86  		if s.core.IsInitialized() {
    87  			route := loginRoute
    88  			if extractUserInfo(r).Authed {
    89  				route = walletsRoute
    90  			}
    91  			http.Redirect(w, r, route, http.StatusSeeOther)
    92  			return
    93  		}
    94  		next.ServeHTTP(w, r)
    95  	})
    96  }
    97  
    98  // rejectUninited is like requireInit except that it responds with an error
    99  // instead of redirecting to the register path.
   100  func (s *WebServer) rejectUninited(next http.Handler) http.Handler {
   101  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   102  		if !s.core.IsInitialized() {
   103  			http.Error(w, http.StatusText(http.StatusPreconditionRequired), http.StatusPreconditionRequired)
   104  			return
   105  		}
   106  		next.ServeHTTP(w, r)
   107  	})
   108  }
   109  
   110  // requireLogin ensures that the user is authenticated (has logged in) before
   111  // allowing the incoming request to proceed. Redirects to login page if user is
   112  // not logged in. This check should typically be performed after checking that
   113  // the app is initialized.
   114  func (s *WebServer) requireLogin(next http.Handler) http.Handler {
   115  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   116  		if !s.isAuthed(r) {
   117  			http.Redirect(w, r, loginRoute, http.StatusSeeOther)
   118  			return
   119  		}
   120  		next.ServeHTTP(w, r)
   121  	})
   122  }
   123  
   124  // rejectUnauthed is like requireLogin except that it responds with an error
   125  // instead of redirecting to the login path.
   126  func (s *WebServer) rejectUnauthed(next http.Handler) http.Handler {
   127  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   128  		if !s.isAuthed(r) {
   129  			http.Error(w, "not authorized - login first", http.StatusUnauthorized)
   130  			return
   131  		}
   132  		next.ServeHTTP(w, r)
   133  	})
   134  }
   135  
   136  // requireDEXConnection ensures that the user has completely registered with at
   137  // least 1 DEX before allowing the incoming request to proceed. Redirects to the
   138  // register page if the user has not connected any DEX.
   139  func (s *WebServer) requireDEXConnection(next http.Handler) http.Handler {
   140  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   141  		if len(s.core.Exchanges()) == 0 {
   142  			http.Redirect(w, r, registerRoute, http.StatusSeeOther)
   143  			return
   144  		}
   145  		next.ServeHTTP(w, r)
   146  	})
   147  }
   148  
   149  // dexHostCtx embeds the host into the request context.
   150  func dexHostCtx(next http.Handler) http.Handler {
   151  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   152  		host := chi.URLParam(r, "host")
   153  		ctx := context.WithValue(r.Context(), ctxHost, host)
   154  		next.ServeHTTP(w, r.WithContext(ctx))
   155  	})
   156  }
   157  
   158  // getHostCtx interprets the context value at ctxHost as a string host.
   159  func getHostCtx(r *http.Request) (string, error) {
   160  	untypedHost := r.Context().Value(ctxHost)
   161  	if untypedHost == nil {
   162  		return "", errors.New("host not set in request")
   163  	}
   164  	host, ok := untypedHost.(string)
   165  	if !ok {
   166  		return "", errors.New("type assertion failed")
   167  	}
   168  	return host, nil
   169  }
   170  
   171  // orderIDCtx embeds order ID into the request context.
   172  func orderIDCtx(next http.Handler) http.Handler {
   173  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   174  		oid := chi.URLParam(r, "oid")
   175  		ctx := context.WithValue(r.Context(), ctxOID, oid)
   176  		next.ServeHTTP(w, r.WithContext(ctx))
   177  	})
   178  }
   179  
   180  // getOrderIDCtx interprets the context value at ctxOID as a dex.Bytes order ID.
   181  func getOrderIDCtx(r *http.Request) (dex.Bytes, error) {
   182  	untypedOID := r.Context().Value(ctxOID)
   183  	if untypedOID == nil {
   184  		log.Errorf("nil value for order ID context value")
   185  	}
   186  	hexID, ok := untypedOID.(string)
   187  	if !ok {
   188  		log.Errorf("getOrderIDCtx type assertion failed. Expected string, got %T", untypedOID)
   189  		return nil, fmt.Errorf("type assertion failed")
   190  	}
   191  
   192  	if len(hexID) != order.OrderIDSize*2 {
   193  		log.Errorf("getOrderIDCtx received order ID string of wrong length. wanted %d, got %d",
   194  			order.OrderIDSize*2, len(hexID))
   195  		return nil, fmt.Errorf("invalid order ID")
   196  	}
   197  	oidB, err := hex.DecodeString(hexID)
   198  	if err != nil {
   199  		log.Errorf("getOrderIDCtx received invalid hex for order ID %q", hexID)
   200  		return nil, fmt.Errorf("")
   201  	}
   202  	return oidB, nil
   203  }