github.com/navikt/knorten@v0.0.0-20240419132333-1333f46ed8b6/pkg/api/service/auth_service.go (about) 1 package service 2 3 import ( 4 "context" 5 "encoding/base64" 6 "encoding/json" 7 "fmt" 8 "strings" 9 "time" 10 11 "github.com/navikt/knorten/pkg/common" 12 13 "github.com/navikt/knorten/pkg/database" 14 15 "github.com/golang-jwt/jwt/v4" 16 "github.com/navikt/knorten/pkg/api/auth" 17 ) 18 19 type AuthService interface { 20 GetLoginURL(state string) string 21 CreateSession(ctx context.Context, code string) (*auth.Session, error) 22 DeleteSession(ctx context.Context, token string) error 23 } 24 25 type authService struct { 26 azureClient *auth.Azure 27 tokenLength int 28 sessionLength time.Duration 29 adminGroupID string 30 repo *database.Repo 31 } 32 33 func (s *authService) DeleteSession(ctx context.Context, token string) error { 34 err := s.repo.SessionDelete(ctx, token) 35 if err != nil { 36 return fmt.Errorf("delete session: %w", err) 37 } 38 39 return nil 40 } 41 42 func (s *authService) CreateSession(ctx context.Context, code string) (*auth.Session, error) { 43 tokens, err := s.azureClient.Exchange(ctx, code) 44 if err != nil { 45 return nil, fmt.Errorf("exchange authorization code for tokens: %w", err) 46 } 47 48 rawIDToken, ok := tokens.Extra("id_token").(string) 49 if !ok { 50 return nil, fmt.Errorf("missing id_token") 51 } 52 53 // Parse and verify ID Token payload. 54 _, err = s.azureClient.Verify(ctx, rawIDToken) 55 if err != nil { 56 return nil, fmt.Errorf("verify ID token: %w", err) 57 } 58 59 sess := &auth.Session{ 60 Token: common.GenerateSecureToken(s.tokenLength), 61 Expires: time.Now().Add(s.sessionLength), 62 AccessToken: tokens.AccessToken, 63 } 64 65 b, err := base64.RawStdEncoding.DecodeString(strings.Split(tokens.AccessToken, ".")[1]) 66 if err != nil { 67 return nil, fmt.Errorf("decode access token: %w", err) 68 } 69 70 if err := json.Unmarshal(b, sess); err != nil { 71 return nil, fmt.Errorf("unmarshal access token: %w", err) 72 } 73 74 sess.IsAdmin, err = s.isUserInAdminGroup(sess.AccessToken) 75 if err != nil { 76 return nil, fmt.Errorf("check if user is in admin group: %w", err) 77 } 78 79 err = s.repo.SessionCreate(ctx, sess) 80 if err != nil { 81 return nil, fmt.Errorf("create session: %w", err) 82 } 83 84 return sess, nil 85 } 86 87 func (s *authService) isUserInAdminGroup(token string) (bool, error) { 88 var claims jwt.MapClaims 89 90 certificates, err := s.azureClient.FetchCertificates() 91 if err != nil { 92 return false, fmt.Errorf("fetch certificates: %w", err) 93 } 94 95 jwtValidator := auth.JWTValidator(certificates, s.azureClient.ClientID) 96 97 _, err = jwt.ParseWithClaims(token, &claims, jwtValidator) 98 if err != nil { 99 return false, fmt.Errorf("parse claims: %w", err) 100 } 101 102 if claims["groups"] == nil { 103 return false, nil 104 } 105 106 groups, ok := claims["groups"].([]interface{}) 107 if !ok { 108 return false, nil 109 } 110 111 for _, group := range groups { 112 if grp, ok := group.(string); ok { 113 if grp == s.adminGroupID { 114 return true, nil 115 } 116 } 117 } 118 119 return false, nil 120 } 121 122 func (s *authService) GetLoginURL(state string) string { 123 return s.azureClient.AuthCodeURL(state) 124 } 125 126 func NewAuthService(repo *database.Repo, adminGroupID string, sessionLength time.Duration, tokenLength int, azureClient *auth.Azure) *authService { 127 return &authService{ 128 azureClient: azureClient, 129 tokenLength: tokenLength, 130 sessionLength: sessionLength, 131 adminGroupID: adminGroupID, 132 repo: repo, 133 } 134 }