github.com/chenbh/concourse/v6@v6.4.2/skymarshal/skyserver/skyserver.go (about) 1 package skyserver 2 3 import ( 4 "context" 5 "crypto/rand" 6 "encoding/base64" 7 "encoding/hex" 8 "encoding/json" 9 "errors" 10 "fmt" 11 "net/http" 12 "net/url" 13 "strings" 14 "time" 15 16 "code.cloudfoundry.org/lager" 17 "github.com/chenbh/concourse/v6/skymarshal/token" 18 "golang.org/x/oauth2" 19 "gopkg.in/square/go-jose.v2/jwt" 20 ) 21 22 type SkyConfig struct { 23 Logger lager.Logger 24 TokenMiddleware token.Middleware 25 TokenParser token.Parser 26 OAuthConfig *oauth2.Config 27 HTTPClient *http.Client 28 } 29 30 func NewSkyHandler(server *SkyServer) http.Handler { 31 handler := http.NewServeMux() 32 handler.HandleFunc("/sky/login", server.Login) 33 handler.HandleFunc("/sky/logout", server.Logout) 34 handler.HandleFunc("/sky/callback", server.Callback) 35 return handler 36 } 37 38 func NewSkyServer(config *SkyConfig) (*SkyServer, error) { 39 return &SkyServer{config}, nil 40 } 41 42 type SkyServer struct { 43 config *SkyConfig 44 } 45 46 func (s *SkyServer) Login(w http.ResponseWriter, r *http.Request) { 47 48 logger := s.config.Logger.Session("login") 49 50 tokenString := s.config.TokenMiddleware.GetAuthToken(r) 51 if tokenString == "" { 52 s.NewLogin(w, r) 53 return 54 } 55 56 redirectURI := r.FormValue("redirect_uri") 57 if redirectURI == "" { 58 redirectURI = "/" 59 } 60 61 parts := strings.Split(tokenString, " ") 62 63 if len(parts) != 2 || !strings.EqualFold(parts[0], "bearer") { 64 logger.Info("failed-to-parse-cookie") 65 s.NewLogin(w, r) 66 return 67 } 68 69 expiry, err := s.config.TokenParser.ParseExpiry(parts[1]) 70 if err != nil { 71 logger.Error("failed-to-parse-expiration", err) 72 s.NewLogin(w, r) 73 return 74 } 75 nowWithLeeway := time.Now().Add(-jwt.DefaultLeeway) 76 if expiry.Before(nowWithLeeway) { 77 logger.Info("token-is-expired") 78 s.NewLogin(w, r) 79 return 80 } 81 82 oauth2Token := &oauth2.Token{ 83 TokenType: parts[0], 84 AccessToken: parts[1], 85 Expiry: expiry, 86 } 87 88 s.Redirect(w, r, oauth2Token, redirectURI) 89 } 90 91 func (s *SkyServer) NewLogin(w http.ResponseWriter, r *http.Request) { 92 93 logger := s.config.Logger.Session("new-login") 94 95 redirectURI := r.FormValue("redirect_uri") 96 if redirectURI == "" { 97 redirectURI = "/" 98 } 99 100 stateToken := encode(stateToken{ 101 RedirectURI: redirectURI, 102 Entropy: randomString(), 103 }) 104 105 err := s.config.TokenMiddleware.SetStateToken(w, stateToken, time.Now().Add(time.Hour)) 106 if err != nil { 107 logger.Error("invalid-state-token", err) 108 w.WriteHeader(http.StatusInternalServerError) 109 return 110 } 111 112 authCodeURL := s.config.OAuthConfig.AuthCodeURL(stateToken, oauth2.AccessTypeOffline) 113 114 http.Redirect(w, r, authCodeURL, http.StatusTemporaryRedirect) 115 } 116 117 func (s *SkyServer) Callback(w http.ResponseWriter, r *http.Request) { 118 119 logger := s.config.Logger.Session("callback") 120 121 if errMsg, errDesc := r.FormValue("error"), r.FormValue("error_description"); errMsg != "" { 122 logger.Error("failed-with-callback-error", errors.New(errMsg+" : "+errDesc)) 123 http.Error(w, errMsg, http.StatusBadRequest) 124 return 125 } 126 127 stateToken := s.config.TokenMiddleware.GetStateToken(r) 128 if stateToken == "" { 129 logger.Error("failed-with-invalid-state-token", errors.New("state token is empty")) 130 http.Error(w, "invalid state token", http.StatusBadRequest) 131 return 132 } 133 134 if stateToken != r.FormValue("state") { 135 logger.Error("failed-with-unexpected-state-token", errors.New("state token does not match")) 136 http.Error(w, "unexpected state token", http.StatusBadRequest) 137 return 138 } 139 140 s.config.TokenMiddleware.UnsetStateToken(w) 141 142 ctx := context.WithValue(r.Context(), oauth2.HTTPClient, s.config.HTTPClient) 143 144 dexToken, err := s.config.OAuthConfig.Exchange(ctx, r.FormValue("code")) 145 if err != nil { 146 logger.Error("failed-to-fetch-dex-token", err) 147 switch e := err.(type) { 148 case *oauth2.RetrieveError: 149 http.Error(w, string(e.Body), e.Response.StatusCode) 150 return 151 default: 152 http.Error(w, err.Error(), http.StatusBadRequest) 153 return 154 } 155 } 156 157 s.Redirect(w, r, dexToken, decode(stateToken).RedirectURI) 158 } 159 160 func (s *SkyServer) Redirect(w http.ResponseWriter, r *http.Request, oauth2Token *oauth2.Token, redirectURI string) { 161 logger := s.config.Logger.Session("redirect") 162 163 redirectURL, err := url.ParseRequestURI(redirectURI) 164 if err != nil { 165 logger.Error("failed-to-parse-redirect-url", err) 166 w.WriteHeader(http.StatusBadRequest) 167 return 168 } 169 170 if redirectURL.Host != "" { 171 logger.Error("invalid-redirect", fmt.Errorf("Unsupported redirect uri: %s", redirectURI)) 172 w.WriteHeader(http.StatusBadRequest) 173 return 174 } 175 176 err = s.config.TokenMiddleware.SetAuthToken(w, oauth2Token.TokenType+" "+oauth2Token.AccessToken, oauth2Token.Expiry) 177 if err != nil { 178 logger.Error("failed-to-set-auth-token", err) 179 w.WriteHeader(http.StatusInternalServerError) 180 return 181 } 182 183 csrfToken := randomString() 184 185 err = s.config.TokenMiddleware.SetCSRFToken(w, csrfToken, oauth2Token.Expiry) 186 if err != nil { 187 logger.Error("failed-to-set-state-token", err) 188 w.WriteHeader(http.StatusInternalServerError) 189 return 190 } 191 192 params := redirectURL.Query() 193 params.Set("csrf_token", csrfToken) 194 195 http.Redirect(w, r, redirectURL.EscapedPath()+"?"+params.Encode(), http.StatusTemporaryRedirect) 196 } 197 198 func (s *SkyServer) Logout(w http.ResponseWriter, r *http.Request) { 199 s.config.TokenMiddleware.UnsetAuthToken(w) 200 s.config.TokenMiddleware.UnsetCSRFToken(w) 201 } 202 203 type stateToken struct { 204 RedirectURI string `json:"redirect_uri"` 205 Entropy string `json:"entropy"` 206 } 207 208 func encode(token stateToken) string { 209 json, _ := json.Marshal(token) 210 211 return base64.StdEncoding.EncodeToString(json) 212 } 213 214 func decode(raw string) stateToken { 215 data, _ := base64.StdEncoding.DecodeString(raw) 216 217 var token stateToken 218 json.Unmarshal(data, &token) 219 return token 220 } 221 222 func randomString() string { 223 bytes := make([]byte, 32) 224 rand.Read(bytes) 225 return hex.EncodeToString(bytes) 226 }