github.com/nais/knorten@v0.0.0-20240104110906-55926958e361/pkg/api/auth.go (about)

     1  package api
     2  
     3  import (
     4  	"context"
     5  	"crypto/rand"
     6  	"database/sql"
     7  	"encoding/base64"
     8  	"encoding/hex"
     9  	"encoding/json"
    10  	"errors"
    11  	"fmt"
    12  	"net"
    13  	"net/http"
    14  	"strings"
    15  	"time"
    16  
    17  	"github.com/gin-contrib/sessions"
    18  	"github.com/gin-gonic/gin"
    19  	"github.com/golang-jwt/jwt/v4"
    20  	"github.com/google/uuid"
    21  	"github.com/nais/knorten/pkg/api/auth"
    22  	"k8s.io/utils/strings/slices"
    23  )
    24  
    25  const (
    26  	RedirectURICookie = "redirecturi"
    27  	OAuthStateCookie  = "oauthstate"
    28  	sessionCookie     = "knorten_session"
    29  	tokenLength       = 32
    30  	sessionLength     = 1 * time.Hour
    31  )
    32  
    33  func (c *client) login(ctx *gin.Context) string {
    34  	host, _, err := net.SplitHostPort(ctx.Request.Host)
    35  	if err != nil {
    36  		host = ctx.Request.Host
    37  	}
    38  
    39  	redirectURI := ctx.Request.URL.Query().Get("redirect_uri")
    40  	ctx.SetCookie(
    41  		RedirectURICookie,
    42  		redirectURI,
    43  		time.Now().Add(30*time.Minute).Second(),
    44  		"/",
    45  		host,
    46  		true,
    47  		true,
    48  	)
    49  
    50  	oauthState := uuid.New().String()
    51  	ctx.SetCookie(
    52  		OAuthStateCookie,
    53  		oauthState,
    54  		time.Now().Add(30*time.Minute).Second(),
    55  		"/",
    56  		host,
    57  		true,
    58  		true,
    59  	)
    60  
    61  	return c.azureClient.AuthCodeURL(oauthState)
    62  }
    63  
    64  func (c *client) callback(ctx *gin.Context) (string, error) {
    65  	host, _, err := net.SplitHostPort(ctx.Request.Host)
    66  	if err != nil {
    67  		host = ctx.Request.Host
    68  	}
    69  	loginPage := "/oversikt"
    70  
    71  	redirectURI, _ := ctx.Cookie(RedirectURICookie)
    72  	if redirectURI != "" {
    73  		loginPage = redirectURI
    74  	}
    75  
    76  	if strings.HasPrefix(ctx.Request.Host, "localhost") {
    77  		loginPage = "http://localhost:8080" + loginPage
    78  	}
    79  
    80  	deleteCookie(ctx, RedirectURICookie, host)
    81  	code := ctx.Request.URL.Query().Get("code")
    82  	if len(code) == 0 {
    83  		return loginPage + "?error=unauthenticated", errors.New("unauthenticated")
    84  	}
    85  
    86  	oauthCookie, err := ctx.Cookie(OAuthStateCookie)
    87  	if err != nil {
    88  		c.log.Infof("Missing oauth state cookie: %v", err)
    89  		return loginPage + "?error=invalid-state", errors.New("invalid state")
    90  	}
    91  
    92  	deleteCookie(ctx, OAuthStateCookie, host)
    93  
    94  	state := ctx.Request.URL.Query().Get("state")
    95  	if state != oauthCookie {
    96  		c.log.Info("Incoming state does not match local state")
    97  		return loginPage + "?error=invalid-state", errors.New("invalid state")
    98  	}
    99  
   100  	tokens, err := c.azureClient.Exchange(ctx.Request.Context(), code)
   101  	if err != nil {
   102  		if !errors.Is(err, context.Canceled) {
   103  			c.log.Errorf("Exchanging authorization code for tokens: %v", err)
   104  		}
   105  		return loginPage + "?error=invalid-state", errors.New("forbidden")
   106  	}
   107  
   108  	rawIDToken, ok := tokens.Extra("id_token").(string)
   109  	if !ok {
   110  		c.log.Info("Missing id_token")
   111  		return loginPage + "?error=unauthenticated", errors.New("unauthenticated")
   112  	}
   113  
   114  	// Parse and verify ID Token payload.
   115  	_, err = c.azureClient.Verify(ctx.Request.Context(), rawIDToken)
   116  	if err != nil {
   117  		c.log.Info("Invalid id_token")
   118  		return loginPage + "?error=unauthenticated", errors.New("unauthenticated")
   119  	}
   120  
   121  	session := &auth.Session{
   122  		Token:       generateSecureToken(tokenLength),
   123  		Expires:     time.Now().Add(sessionLength),
   124  		AccessToken: tokens.AccessToken,
   125  	}
   126  
   127  	b, err := base64.RawStdEncoding.DecodeString(strings.Split(tokens.AccessToken, ".")[1])
   128  	if err != nil {
   129  		c.log.WithError(err).Error("unable decode access token")
   130  		return loginPage + "?error=unauthenticated", errors.New("unauthenticated")
   131  	}
   132  
   133  	if err := json.Unmarshal(b, session); err != nil {
   134  		c.log.WithError(err).Error("unable unmarshalling token")
   135  		return loginPage + "?error=unauthenticated", errors.New("unauthenticated")
   136  	}
   137  
   138  	session.IsAdmin = c.isUserInAdminGroup(session.AccessToken)
   139  
   140  	if err := c.repo.SessionCreate(ctx, session); err != nil {
   141  		c.log.WithError(err).Error("unable to create session")
   142  		return loginPage + "?error=internal-server-error", errors.New("unable to create session")
   143  	}
   144  
   145  	ctx.SetCookie(
   146  		sessionCookie,
   147  		session.Token,
   148  		86400,
   149  		"/",
   150  		host,
   151  		true,
   152  		true,
   153  	)
   154  
   155  	return loginPage, nil
   156  }
   157  
   158  func (c *client) logout(ctx *gin.Context) (string, error) {
   159  	host, _, err := net.SplitHostPort(ctx.Request.Host)
   160  	if err != nil {
   161  		host = ctx.Request.Host
   162  	}
   163  
   164  	deleteCookie(ctx, sessionCookie, host)
   165  
   166  	var loginPage string
   167  	if strings.HasPrefix(ctx.Request.Host, "localhost") {
   168  		loginPage = "http://localhost:8080/"
   169  	} else {
   170  		loginPage = "/"
   171  	}
   172  
   173  	err = c.repo.SessionDelete(ctx, sessionCookie)
   174  	if err != nil {
   175  		c.log.WithError(err).Error("failed deleting session")
   176  		return loginPage, err
   177  	}
   178  
   179  	return loginPage, nil
   180  }
   181  
   182  func generateSecureToken(length int) string {
   183  	b := make([]byte, length)
   184  	if _, err := rand.Read(b); err != nil {
   185  		return ""
   186  	}
   187  	return hex.EncodeToString(b)
   188  }
   189  
   190  func deleteCookie(ctx *gin.Context, name, host string) {
   191  	ctx.SetCookie(
   192  		name,
   193  		"",
   194  		time.Unix(0, 0).Second(),
   195  		"/",
   196  		host,
   197  		true,
   198  		true,
   199  	)
   200  }
   201  
   202  func (c *client) authMiddleware() gin.HandlerFunc {
   203  	if c.dryRun {
   204  		return func(ctx *gin.Context) {
   205  			user := &auth.User{
   206  				Name:    "Dum My",
   207  				Email:   "dummy@nav.no",
   208  				Expires: time.Time{},
   209  			}
   210  			ctx.Set("user", user)
   211  			ctx.Next()
   212  		}
   213  	}
   214  
   215  	certificates, err := c.azureClient.FetchCertificates()
   216  	if err != nil {
   217  		c.log.Fatalf("Fetching signing certificates from IdP: %v", err)
   218  	}
   219  
   220  	return func(ctx *gin.Context) {
   221  		sessionToken, err := ctx.Cookie(sessionCookie)
   222  		if err != nil {
   223  			ctx.Redirect(http.StatusSeeOther, "/oauth2/login")
   224  			return
   225  		}
   226  
   227  		session, err := c.repo.SessionGet(ctx, sessionToken)
   228  		if err != nil {
   229  			if errors.Is(err, sql.ErrNoRows) {
   230  				ctx.Redirect(http.StatusSeeOther, "/oauth2/login")
   231  				return
   232  			}
   233  			ctx.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
   234  			return
   235  		}
   236  
   237  		user, err := c.azureClient.ValidateUser(certificates, session.AccessToken)
   238  		if err != nil {
   239  			if errors.Is(err, auth.ErrAzureTokenExpired) {
   240  				ctx.Redirect(http.StatusSeeOther, "/oauth2/login")
   241  				return
   242  			}
   243  			ctx.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized validate user"})
   244  			return
   245  		}
   246  
   247  		teamSlug := ctx.Param("slug")
   248  		if teamSlug != "" {
   249  			team, err := c.repo.TeamBySlugGet(ctx, teamSlug)
   250  			if err != nil {
   251  				c.log.WithError(err).Errorf("problem checking for authorization %v", user.Email)
   252  				ctx.Redirect(http.StatusSeeOther, "/")
   253  				return
   254  			}
   255  
   256  			if !slices.Contains(team.Users, strings.ToLower(user.Email)) {
   257  				sess := sessions.Default(ctx)
   258  				sess.AddFlash(fmt.Sprintf("%v is not authorized", user.Email))
   259  				err = sess.Save()
   260  				if err != nil {
   261  					c.log.WithError(err).Error("problem saving session")
   262  					ctx.Redirect(http.StatusSeeOther, "/")
   263  					return
   264  				}
   265  				ctx.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": fmt.Sprintf("%v is not part of team %v", user.Email, teamSlug)})
   266  				return
   267  			}
   268  		}
   269  
   270  		ctx.Set("user", user)
   271  		ctx.Next()
   272  	}
   273  }
   274  
   275  func (c *client) adminAuthMiddleware() gin.HandlerFunc {
   276  	if c.dryRun {
   277  		return func(ctx *gin.Context) {
   278  			user := &auth.User{
   279  				Name:    "Dum My",
   280  				Email:   "dummy@nav.no",
   281  				Expires: time.Time{},
   282  			}
   283  			ctx.Set("user", user)
   284  			ctx.Next()
   285  		}
   286  	}
   287  	return func(ctx *gin.Context) {
   288  		if !c.isAdmin(ctx) {
   289  			ctx.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
   290  		}
   291  
   292  		ctx.Next()
   293  	}
   294  }
   295  
   296  func (c *client) isUserInAdminGroup(token string) bool {
   297  	var claims jwt.MapClaims
   298  
   299  	certificates, err := c.azureClient.FetchCertificates()
   300  	if err != nil {
   301  		c.log.WithError(err).Error("fetch certificates")
   302  		return false
   303  	}
   304  
   305  	jwtValidator := auth.JWTValidator(certificates, c.azureClient.ClientID)
   306  
   307  	_, err = jwt.ParseWithClaims(token, &claims, jwtValidator)
   308  
   309  	if err != nil {
   310  		c.log.WithError(err).Error("Parse token")
   311  		return false
   312  	}
   313  
   314  	if claims["groups"] != nil {
   315  		groups, ok := claims["groups"].([]interface{})
   316  		if !ok {
   317  			c.log.Logger.Error("User does not have groups in claims")
   318  			return false
   319  		}
   320  		for _, group := range groups {
   321  			grp, ok := group.(string)
   322  			if ok {
   323  				if grp == c.adminGroupID {
   324  					return true
   325  				}
   326  			}
   327  		}
   328  	}
   329  	return false
   330  }
   331  
   332  func (c *client) setupAuthRoutes() {
   333  	c.router.GET("/oauth2/login", func(ctx *gin.Context) {
   334  		if c.dryRun {
   335  			if err := c.createDryRunSession(ctx); err != nil {
   336  				c.log.Error("creating dryrun session")
   337  			}
   338  			ctx.Redirect(http.StatusSeeOther, "http://localhost:8080/oversikt")
   339  			return
   340  		}
   341  
   342  		consentURL := c.login(ctx)
   343  		ctx.Redirect(http.StatusSeeOther, consentURL)
   344  	})
   345  
   346  	c.router.GET("/oauth2/callback", func(ctx *gin.Context) {
   347  		redirectURL, err := c.callback(ctx)
   348  		if err != nil {
   349  			session := sessions.Default(ctx)
   350  			session.AddFlash(err.Error())
   351  			err := session.Save()
   352  			if err != nil {
   353  				c.log.WithError(err).Error("problem saving session")
   354  				ctx.Redirect(http.StatusSeeOther, "/")
   355  				return
   356  			}
   357  			ctx.Redirect(http.StatusSeeOther, "/")
   358  			return
   359  		}
   360  
   361  		ctx.Redirect(http.StatusSeeOther, redirectURL)
   362  	})
   363  
   364  	c.router.GET("/oauth2/logout", func(ctx *gin.Context) {
   365  		redirectURL, err := c.logout(ctx)
   366  		if err != nil {
   367  			session := sessions.Default(ctx)
   368  			session.AddFlash(err.Error())
   369  			err := session.Save()
   370  			if err != nil {
   371  				c.log.WithError(err).Error("problem saving session")
   372  				ctx.Redirect(http.StatusSeeOther, "/")
   373  				return
   374  			}
   375  			ctx.Redirect(http.StatusSeeOther, "/")
   376  			return
   377  		}
   378  		ctx.Redirect(http.StatusSeeOther, redirectURL)
   379  	})
   380  }
   381  
   382  func (c *client) createDryRunSession(ctx *gin.Context) error {
   383  	session := &auth.Session{
   384  		Token:       generateSecureToken(tokenLength),
   385  		Expires:     time.Now().Add(sessionLength),
   386  		AccessToken: "",
   387  		IsAdmin:     true,
   388  	}
   389  
   390  	if err := c.repo.SessionCreate(ctx, session); err != nil {
   391  		c.log.WithError(err).Error("unable to create session")
   392  		return errors.New("unable to create session")
   393  	}
   394  
   395  	ctx.SetCookie(
   396  		sessionCookie,
   397  		session.Token,
   398  		86400,
   399  		"/",
   400  		"localhost",
   401  		true,
   402  		true,
   403  	)
   404  
   405  	return nil
   406  }
   407  
   408  func getUser(ctx *gin.Context) (*auth.User, error) {
   409  	var user *auth.User
   410  	anyUser, exists := ctx.Get("user")
   411  	if !exists {
   412  		return nil, fmt.Errorf("can't verify user")
   413  	}
   414  	user = anyUser.(*auth.User)
   415  
   416  	return user, nil
   417  }
   418  
   419  func getNormalizedNameFromEmail(name string) string {
   420  	name = strings.Split(name, "@")[0]
   421  	name = strings.ReplaceAll(name, ".", "-")
   422  	return strings.ToLower(name)
   423  }