github.com/nais/knorten@v0.0.0-20240104110906-55926958e361/pkg/api/auth.go (about) 1 package api 2 3 import ( 4 "context" 5 "crypto/rand" 6 "database/sql" 7 "encoding/base64" 8 "encoding/hex" 9 "encoding/json" 10 "errors" 11 "fmt" 12 "net" 13 "net/http" 14 "strings" 15 "time" 16 17 "github.com/gin-contrib/sessions" 18 "github.com/gin-gonic/gin" 19 "github.com/golang-jwt/jwt/v4" 20 "github.com/google/uuid" 21 "github.com/nais/knorten/pkg/api/auth" 22 "k8s.io/utils/strings/slices" 23 ) 24 25 const ( 26 RedirectURICookie = "redirecturi" 27 OAuthStateCookie = "oauthstate" 28 sessionCookie = "knorten_session" 29 tokenLength = 32 30 sessionLength = 1 * time.Hour 31 ) 32 33 func (c *client) login(ctx *gin.Context) string { 34 host, _, err := net.SplitHostPort(ctx.Request.Host) 35 if err != nil { 36 host = ctx.Request.Host 37 } 38 39 redirectURI := ctx.Request.URL.Query().Get("redirect_uri") 40 ctx.SetCookie( 41 RedirectURICookie, 42 redirectURI, 43 time.Now().Add(30*time.Minute).Second(), 44 "/", 45 host, 46 true, 47 true, 48 ) 49 50 oauthState := uuid.New().String() 51 ctx.SetCookie( 52 OAuthStateCookie, 53 oauthState, 54 time.Now().Add(30*time.Minute).Second(), 55 "/", 56 host, 57 true, 58 true, 59 ) 60 61 return c.azureClient.AuthCodeURL(oauthState) 62 } 63 64 func (c *client) callback(ctx *gin.Context) (string, error) { 65 host, _, err := net.SplitHostPort(ctx.Request.Host) 66 if err != nil { 67 host = ctx.Request.Host 68 } 69 loginPage := "/oversikt" 70 71 redirectURI, _ := ctx.Cookie(RedirectURICookie) 72 if redirectURI != "" { 73 loginPage = redirectURI 74 } 75 76 if strings.HasPrefix(ctx.Request.Host, "localhost") { 77 loginPage = "http://localhost:8080" + loginPage 78 } 79 80 deleteCookie(ctx, RedirectURICookie, host) 81 code := ctx.Request.URL.Query().Get("code") 82 if len(code) == 0 { 83 return loginPage + "?error=unauthenticated", errors.New("unauthenticated") 84 } 85 86 oauthCookie, err := ctx.Cookie(OAuthStateCookie) 87 if err != nil { 88 c.log.Infof("Missing oauth state cookie: %v", err) 89 return loginPage + "?error=invalid-state", errors.New("invalid state") 90 } 91 92 deleteCookie(ctx, OAuthStateCookie, host) 93 94 state := ctx.Request.URL.Query().Get("state") 95 if state != oauthCookie { 96 c.log.Info("Incoming state does not match local state") 97 return loginPage + "?error=invalid-state", errors.New("invalid state") 98 } 99 100 tokens, err := c.azureClient.Exchange(ctx.Request.Context(), code) 101 if err != nil { 102 if !errors.Is(err, context.Canceled) { 103 c.log.Errorf("Exchanging authorization code for tokens: %v", err) 104 } 105 return loginPage + "?error=invalid-state", errors.New("forbidden") 106 } 107 108 rawIDToken, ok := tokens.Extra("id_token").(string) 109 if !ok { 110 c.log.Info("Missing id_token") 111 return loginPage + "?error=unauthenticated", errors.New("unauthenticated") 112 } 113 114 // Parse and verify ID Token payload. 115 _, err = c.azureClient.Verify(ctx.Request.Context(), rawIDToken) 116 if err != nil { 117 c.log.Info("Invalid id_token") 118 return loginPage + "?error=unauthenticated", errors.New("unauthenticated") 119 } 120 121 session := &auth.Session{ 122 Token: generateSecureToken(tokenLength), 123 Expires: time.Now().Add(sessionLength), 124 AccessToken: tokens.AccessToken, 125 } 126 127 b, err := base64.RawStdEncoding.DecodeString(strings.Split(tokens.AccessToken, ".")[1]) 128 if err != nil { 129 c.log.WithError(err).Error("unable decode access token") 130 return loginPage + "?error=unauthenticated", errors.New("unauthenticated") 131 } 132 133 if err := json.Unmarshal(b, session); err != nil { 134 c.log.WithError(err).Error("unable unmarshalling token") 135 return loginPage + "?error=unauthenticated", errors.New("unauthenticated") 136 } 137 138 session.IsAdmin = c.isUserInAdminGroup(session.AccessToken) 139 140 if err := c.repo.SessionCreate(ctx, session); err != nil { 141 c.log.WithError(err).Error("unable to create session") 142 return loginPage + "?error=internal-server-error", errors.New("unable to create session") 143 } 144 145 ctx.SetCookie( 146 sessionCookie, 147 session.Token, 148 86400, 149 "/", 150 host, 151 true, 152 true, 153 ) 154 155 return loginPage, nil 156 } 157 158 func (c *client) logout(ctx *gin.Context) (string, error) { 159 host, _, err := net.SplitHostPort(ctx.Request.Host) 160 if err != nil { 161 host = ctx.Request.Host 162 } 163 164 deleteCookie(ctx, sessionCookie, host) 165 166 var loginPage string 167 if strings.HasPrefix(ctx.Request.Host, "localhost") { 168 loginPage = "http://localhost:8080/" 169 } else { 170 loginPage = "/" 171 } 172 173 err = c.repo.SessionDelete(ctx, sessionCookie) 174 if err != nil { 175 c.log.WithError(err).Error("failed deleting session") 176 return loginPage, err 177 } 178 179 return loginPage, nil 180 } 181 182 func generateSecureToken(length int) string { 183 b := make([]byte, length) 184 if _, err := rand.Read(b); err != nil { 185 return "" 186 } 187 return hex.EncodeToString(b) 188 } 189 190 func deleteCookie(ctx *gin.Context, name, host string) { 191 ctx.SetCookie( 192 name, 193 "", 194 time.Unix(0, 0).Second(), 195 "/", 196 host, 197 true, 198 true, 199 ) 200 } 201 202 func (c *client) authMiddleware() gin.HandlerFunc { 203 if c.dryRun { 204 return func(ctx *gin.Context) { 205 user := &auth.User{ 206 Name: "Dum My", 207 Email: "dummy@nav.no", 208 Expires: time.Time{}, 209 } 210 ctx.Set("user", user) 211 ctx.Next() 212 } 213 } 214 215 certificates, err := c.azureClient.FetchCertificates() 216 if err != nil { 217 c.log.Fatalf("Fetching signing certificates from IdP: %v", err) 218 } 219 220 return func(ctx *gin.Context) { 221 sessionToken, err := ctx.Cookie(sessionCookie) 222 if err != nil { 223 ctx.Redirect(http.StatusSeeOther, "/oauth2/login") 224 return 225 } 226 227 session, err := c.repo.SessionGet(ctx, sessionToken) 228 if err != nil { 229 if errors.Is(err, sql.ErrNoRows) { 230 ctx.Redirect(http.StatusSeeOther, "/oauth2/login") 231 return 232 } 233 ctx.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"}) 234 return 235 } 236 237 user, err := c.azureClient.ValidateUser(certificates, session.AccessToken) 238 if err != nil { 239 if errors.Is(err, auth.ErrAzureTokenExpired) { 240 ctx.Redirect(http.StatusSeeOther, "/oauth2/login") 241 return 242 } 243 ctx.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized validate user"}) 244 return 245 } 246 247 teamSlug := ctx.Param("slug") 248 if teamSlug != "" { 249 team, err := c.repo.TeamBySlugGet(ctx, teamSlug) 250 if err != nil { 251 c.log.WithError(err).Errorf("problem checking for authorization %v", user.Email) 252 ctx.Redirect(http.StatusSeeOther, "/") 253 return 254 } 255 256 if !slices.Contains(team.Users, strings.ToLower(user.Email)) { 257 sess := sessions.Default(ctx) 258 sess.AddFlash(fmt.Sprintf("%v is not authorized", user.Email)) 259 err = sess.Save() 260 if err != nil { 261 c.log.WithError(err).Error("problem saving session") 262 ctx.Redirect(http.StatusSeeOther, "/") 263 return 264 } 265 ctx.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": fmt.Sprintf("%v is not part of team %v", user.Email, teamSlug)}) 266 return 267 } 268 } 269 270 ctx.Set("user", user) 271 ctx.Next() 272 } 273 } 274 275 func (c *client) adminAuthMiddleware() gin.HandlerFunc { 276 if c.dryRun { 277 return func(ctx *gin.Context) { 278 user := &auth.User{ 279 Name: "Dum My", 280 Email: "dummy@nav.no", 281 Expires: time.Time{}, 282 } 283 ctx.Set("user", user) 284 ctx.Next() 285 } 286 } 287 return func(ctx *gin.Context) { 288 if !c.isAdmin(ctx) { 289 ctx.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"}) 290 } 291 292 ctx.Next() 293 } 294 } 295 296 func (c *client) isUserInAdminGroup(token string) bool { 297 var claims jwt.MapClaims 298 299 certificates, err := c.azureClient.FetchCertificates() 300 if err != nil { 301 c.log.WithError(err).Error("fetch certificates") 302 return false 303 } 304 305 jwtValidator := auth.JWTValidator(certificates, c.azureClient.ClientID) 306 307 _, err = jwt.ParseWithClaims(token, &claims, jwtValidator) 308 309 if err != nil { 310 c.log.WithError(err).Error("Parse token") 311 return false 312 } 313 314 if claims["groups"] != nil { 315 groups, ok := claims["groups"].([]interface{}) 316 if !ok { 317 c.log.Logger.Error("User does not have groups in claims") 318 return false 319 } 320 for _, group := range groups { 321 grp, ok := group.(string) 322 if ok { 323 if grp == c.adminGroupID { 324 return true 325 } 326 } 327 } 328 } 329 return false 330 } 331 332 func (c *client) setupAuthRoutes() { 333 c.router.GET("/oauth2/login", func(ctx *gin.Context) { 334 if c.dryRun { 335 if err := c.createDryRunSession(ctx); err != nil { 336 c.log.Error("creating dryrun session") 337 } 338 ctx.Redirect(http.StatusSeeOther, "http://localhost:8080/oversikt") 339 return 340 } 341 342 consentURL := c.login(ctx) 343 ctx.Redirect(http.StatusSeeOther, consentURL) 344 }) 345 346 c.router.GET("/oauth2/callback", func(ctx *gin.Context) { 347 redirectURL, err := c.callback(ctx) 348 if err != nil { 349 session := sessions.Default(ctx) 350 session.AddFlash(err.Error()) 351 err := session.Save() 352 if err != nil { 353 c.log.WithError(err).Error("problem saving session") 354 ctx.Redirect(http.StatusSeeOther, "/") 355 return 356 } 357 ctx.Redirect(http.StatusSeeOther, "/") 358 return 359 } 360 361 ctx.Redirect(http.StatusSeeOther, redirectURL) 362 }) 363 364 c.router.GET("/oauth2/logout", func(ctx *gin.Context) { 365 redirectURL, err := c.logout(ctx) 366 if err != nil { 367 session := sessions.Default(ctx) 368 session.AddFlash(err.Error()) 369 err := session.Save() 370 if err != nil { 371 c.log.WithError(err).Error("problem saving session") 372 ctx.Redirect(http.StatusSeeOther, "/") 373 return 374 } 375 ctx.Redirect(http.StatusSeeOther, "/") 376 return 377 } 378 ctx.Redirect(http.StatusSeeOther, redirectURL) 379 }) 380 } 381 382 func (c *client) createDryRunSession(ctx *gin.Context) error { 383 session := &auth.Session{ 384 Token: generateSecureToken(tokenLength), 385 Expires: time.Now().Add(sessionLength), 386 AccessToken: "", 387 IsAdmin: true, 388 } 389 390 if err := c.repo.SessionCreate(ctx, session); err != nil { 391 c.log.WithError(err).Error("unable to create session") 392 return errors.New("unable to create session") 393 } 394 395 ctx.SetCookie( 396 sessionCookie, 397 session.Token, 398 86400, 399 "/", 400 "localhost", 401 true, 402 true, 403 ) 404 405 return nil 406 } 407 408 func getUser(ctx *gin.Context) (*auth.User, error) { 409 var user *auth.User 410 anyUser, exists := ctx.Get("user") 411 if !exists { 412 return nil, fmt.Errorf("can't verify user") 413 } 414 user = anyUser.(*auth.User) 415 416 return user, nil 417 } 418 419 func getNormalizedNameFromEmail(name string) string { 420 name = strings.Split(name, "@")[0] 421 name = strings.ReplaceAll(name, ".", "-") 422 return strings.ToLower(name) 423 }