github.com/fanux/shipyard@v0.0.0-20161009071005-6515ce223235/controller/middleware/access/access.go (about)

     1  package access
     2  
     3  import (
     4  	"fmt"
     5  	"net/http"
     6  	"strings"
     7  
     8  	"github.com/Sirupsen/logrus"
     9  	"github.com/shipyard/shipyard/auth"
    10  	"github.com/shipyard/shipyard/controller/manager"
    11  )
    12  
    13  var (
    14  	logger = logrus.New()
    15  )
    16  
    17  func defaultDeniedHandler(w http.ResponseWriter, r *http.Request) {
    18  	http.Error(w, "access denied", http.StatusForbidden)
    19  }
    20  
    21  type AccessRequired struct {
    22  	deniedHandler http.Handler
    23  	manager       manager.Manager
    24  	acls          []*auth.ACL
    25  }
    26  
    27  func NewAccessRequired(m manager.Manager) *AccessRequired {
    28  	acls := auth.DefaultACLs()
    29  	a := &AccessRequired{
    30  		deniedHandler: http.HandlerFunc(defaultDeniedHandler),
    31  		manager:       m,
    32  		acls:          acls,
    33  	}
    34  	return a
    35  }
    36  
    37  func (a *AccessRequired) Handler(h http.Handler) http.Handler {
    38  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    39  		err := a.handleRequest(w, r)
    40  		if err != nil {
    41  			logger.Warnf("unauthorized request for %s from %s", r.URL.Path, r.RemoteAddr)
    42  			return
    43  		}
    44  		h.ServeHTTP(w, r)
    45  	})
    46  }
    47  
    48  func (a *AccessRequired) handleRequest(w http.ResponseWriter, r *http.Request) error {
    49  	valid := false
    50  	authHeader := r.Header.Get("X-Access-Token")
    51  	parts := strings.Split(authHeader, ":")
    52  	if len(parts) == 2 {
    53  		// validate
    54  		u := parts[0]
    55  		token := parts[1]
    56  		if err := a.manager.VerifyAuthToken(u, token); err == nil {
    57  			acct, err := a.manager.Account(u)
    58  			if err != nil {
    59  				return err
    60  			}
    61  			// check role
    62  			valid = a.checkAccess(acct, r.URL.Path, r.Method)
    63  		}
    64  	} else { // only check access for users; not service keys
    65  		valid = true
    66  	}
    67  
    68  	if !valid {
    69  		a.deniedHandler.ServeHTTP(w, r)
    70  		return fmt.Errorf("access denied %s", r.RemoteAddr)
    71  	}
    72  
    73  	return nil
    74  }
    75  
    76  func (a *AccessRequired) checkRule(rule *auth.AccessRule, path, method string) bool {
    77  	// check wildcard
    78  	if rule.Path == "*" {
    79  		return true
    80  	}
    81  
    82  	// check path
    83  	if strings.HasPrefix(path, rule.Path) {
    84  		// check method
    85  		for _, m := range rule.Methods {
    86  			if m == method {
    87  				return true
    88  			}
    89  		}
    90  	}
    91  
    92  	return false
    93  }
    94  
    95  func (a *AccessRequired) checkRole(role string, path, method string) bool {
    96  	for _, acl := range a.acls {
    97  		// find role
    98  		if acl.RoleName == role {
    99  			for _, rule := range acl.Rules {
   100  				if a.checkRule(rule, path, method) {
   101  					return true
   102  				}
   103  			}
   104  		}
   105  	}
   106  
   107  	return false
   108  }
   109  func (a *AccessRequired) checkAccess(acct *auth.Account, path string, method string) bool {
   110  	// check roles
   111  	for _, role := range acct.Roles {
   112  		// check acls
   113  		if a.checkRole(role, path, method) {
   114  			return true
   115  		}
   116  	}
   117  
   118  	return false
   119  }
   120  
   121  func (a *AccessRequired) HandlerFuncWithNext(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
   122  	err := a.handleRequest(w, r)
   123  	session, _ := a.manager.Store().Get(r, a.manager.StoreKey())
   124  	username := session.Values["username"]
   125  	if err != nil {
   126  		logger.Warnf("access denied for %s to %s from %s", username, r.URL.Path, r.RemoteAddr)
   127  		return
   128  	}
   129  
   130  	if next != nil {
   131  		next(w, r)
   132  	}
   133  }