github.com/fanux/shipyard@v0.0.0-20161009071005-6515ce223235/controller/middleware/auth/auth.go (about) 1 package auth 2 3 import ( 4 "fmt" 5 "net" 6 "net/http" 7 "strings" 8 9 "github.com/Sirupsen/logrus" 10 "github.com/shipyard/shipyard/controller/manager" 11 ) 12 13 var ( 14 logger = logrus.New() 15 ) 16 17 func defaultDeniedHostHandler(w http.ResponseWriter, r *http.Request) { 18 http.Error(w, "unauthorized", http.StatusUnauthorized) 19 } 20 21 type AuthRequired struct { 22 deniedHostHandler http.Handler 23 manager manager.Manager 24 whitelistCIDRs []string 25 } 26 27 func NewAuthRequired(m manager.Manager, whitelistCIDRs []string) *AuthRequired { 28 return &AuthRequired{ 29 deniedHostHandler: http.HandlerFunc(defaultDeniedHostHandler), 30 manager: m, 31 whitelistCIDRs: whitelistCIDRs, 32 } 33 } 34 35 func (a *AuthRequired) Handler(h http.Handler) http.Handler { 36 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 37 err := a.handleRequest(w, r) 38 if err != nil { 39 logger.Warnf("unauthorized request for %s from %s", r.URL.Path, r.RemoteAddr) 40 return 41 } 42 h.ServeHTTP(w, r) 43 }) 44 } 45 46 func (a *AuthRequired) isWhitelisted(addr string) (bool, error) { 47 parts := strings.Split(addr, ":") 48 src := parts[0] 49 50 srcIp := net.ParseIP(src) 51 52 // check each whitelisted ip 53 for _, c := range a.whitelistCIDRs { 54 _, ipNet, err := net.ParseCIDR(c) 55 if err != nil { 56 return false, err 57 } 58 59 if ipNet.Contains(srcIp) { 60 return true, nil 61 } 62 } 63 64 return false, nil 65 } 66 67 func (a *AuthRequired) handleRequest(w http.ResponseWriter, r *http.Request) error { 68 whitelisted, err := a.isWhitelisted(r.RemoteAddr) 69 if err != nil { 70 return err 71 } 72 73 if whitelisted { 74 return nil 75 } 76 77 valid := false 78 // service key takes priority 79 serviceKey := r.Header.Get("X-Service-Key") 80 if serviceKey != "" { 81 if err := a.manager.VerifyServiceKey(serviceKey); err == nil { 82 valid = true 83 } 84 } else { // check for authHeader 85 authHeader := r.Header.Get("X-Access-Token") 86 parts := strings.Split(authHeader, ":") 87 if len(parts) == 2 { 88 // validate 89 user := parts[0] 90 token := parts[1] 91 if err := a.manager.VerifyAuthToken(user, token); err == nil { 92 valid = true 93 // set current user 94 session, _ := a.manager.Store().Get(r, a.manager.StoreKey()) 95 session.Values["username"] = user 96 session.Save(r, w) 97 } 98 } 99 } 100 101 if !valid { 102 a.deniedHostHandler.ServeHTTP(w, r) 103 return fmt.Errorf("unauthorized %s", r.RemoteAddr) 104 } 105 106 return nil 107 } 108 109 func (a *AuthRequired) HandlerFuncWithNext(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { 110 err := a.handleRequest(w, r) 111 112 if err != nil { 113 logger.Warnf("unauthorized request for %s from %s", r.URL.Path, r.RemoteAddr) 114 return 115 } 116 117 if next != nil { 118 next(w, r) 119 } 120 }