github.com/fanux/shipyard@v0.0.0-20161009071005-6515ce223235/controller/middleware/auth/auth.go (about)

     1  package auth
     2  
     3  import (
     4  	"fmt"
     5  	"net"
     6  	"net/http"
     7  	"strings"
     8  
     9  	"github.com/Sirupsen/logrus"
    10  	"github.com/shipyard/shipyard/controller/manager"
    11  )
    12  
    13  var (
    14  	logger = logrus.New()
    15  )
    16  
    17  func defaultDeniedHostHandler(w http.ResponseWriter, r *http.Request) {
    18  	http.Error(w, "unauthorized", http.StatusUnauthorized)
    19  }
    20  
    21  type AuthRequired struct {
    22  	deniedHostHandler http.Handler
    23  	manager           manager.Manager
    24  	whitelistCIDRs    []string
    25  }
    26  
    27  func NewAuthRequired(m manager.Manager, whitelistCIDRs []string) *AuthRequired {
    28  	return &AuthRequired{
    29  		deniedHostHandler: http.HandlerFunc(defaultDeniedHostHandler),
    30  		manager:           m,
    31  		whitelistCIDRs:    whitelistCIDRs,
    32  	}
    33  }
    34  
    35  func (a *AuthRequired) Handler(h http.Handler) http.Handler {
    36  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    37  		err := a.handleRequest(w, r)
    38  		if err != nil {
    39  			logger.Warnf("unauthorized request for %s from %s", r.URL.Path, r.RemoteAddr)
    40  			return
    41  		}
    42  		h.ServeHTTP(w, r)
    43  	})
    44  }
    45  
    46  func (a *AuthRequired) isWhitelisted(addr string) (bool, error) {
    47  	parts := strings.Split(addr, ":")
    48  	src := parts[0]
    49  
    50  	srcIp := net.ParseIP(src)
    51  
    52  	// check each whitelisted ip
    53  	for _, c := range a.whitelistCIDRs {
    54  		_, ipNet, err := net.ParseCIDR(c)
    55  		if err != nil {
    56  			return false, err
    57  		}
    58  
    59  		if ipNet.Contains(srcIp) {
    60  			return true, nil
    61  		}
    62  	}
    63  
    64  	return false, nil
    65  }
    66  
    67  func (a *AuthRequired) handleRequest(w http.ResponseWriter, r *http.Request) error {
    68  	whitelisted, err := a.isWhitelisted(r.RemoteAddr)
    69  	if err != nil {
    70  		return err
    71  	}
    72  
    73  	if whitelisted {
    74  		return nil
    75  	}
    76  
    77  	valid := false
    78  	// service key takes priority
    79  	serviceKey := r.Header.Get("X-Service-Key")
    80  	if serviceKey != "" {
    81  		if err := a.manager.VerifyServiceKey(serviceKey); err == nil {
    82  			valid = true
    83  		}
    84  	} else { // check for authHeader
    85  		authHeader := r.Header.Get("X-Access-Token")
    86  		parts := strings.Split(authHeader, ":")
    87  		if len(parts) == 2 {
    88  			// validate
    89  			user := parts[0]
    90  			token := parts[1]
    91  			if err := a.manager.VerifyAuthToken(user, token); err == nil {
    92  				valid = true
    93  				// set current user
    94  				session, _ := a.manager.Store().Get(r, a.manager.StoreKey())
    95  				session.Values["username"] = user
    96  				session.Save(r, w)
    97  			}
    98  		}
    99  	}
   100  
   101  	if !valid {
   102  		a.deniedHostHandler.ServeHTTP(w, r)
   103  		return fmt.Errorf("unauthorized %s", r.RemoteAddr)
   104  	}
   105  
   106  	return nil
   107  }
   108  
   109  func (a *AuthRequired) HandlerFuncWithNext(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
   110  	err := a.handleRequest(w, r)
   111  
   112  	if err != nil {
   113  		logger.Warnf("unauthorized request for %s from %s", r.URL.Path, r.RemoteAddr)
   114  		return
   115  	}
   116  
   117  	if next != nil {
   118  		next(w, r)
   119  	}
   120  }