github.com/navikt/knorten@v0.0.0-20240419132333-1333f46ed8b6/pkg/api/service/auth_service.go (about)

     1  package service
     2  
     3  import (
     4  	"context"
     5  	"encoding/base64"
     6  	"encoding/json"
     7  	"fmt"
     8  	"strings"
     9  	"time"
    10  
    11  	"github.com/navikt/knorten/pkg/common"
    12  
    13  	"github.com/navikt/knorten/pkg/database"
    14  
    15  	"github.com/golang-jwt/jwt/v4"
    16  	"github.com/navikt/knorten/pkg/api/auth"
    17  )
    18  
    19  type AuthService interface {
    20  	GetLoginURL(state string) string
    21  	CreateSession(ctx context.Context, code string) (*auth.Session, error)
    22  	DeleteSession(ctx context.Context, token string) error
    23  }
    24  
// authService is the AuthService implementation in this file. It uses an Azure
// client for the OAuth2 flow and the database repository to store sessions.
type authService struct {
	azureClient   *auth.Azure    // code exchange, ID-token verification, certificate fetch, login URL
	tokenLength   int            // length argument passed to common.GenerateSecureToken for session tokens
	sessionLength time.Duration  // session lifetime, added to time.Now() when a session is created
	adminGroupID  string         // group ID whose presence in the "groups" claim marks a user as admin
	repo          *database.Repo // session persistence (SessionCreate / SessionDelete)
}
    32  
    33  func (s *authService) DeleteSession(ctx context.Context, token string) error {
    34  	err := s.repo.SessionDelete(ctx, token)
    35  	if err != nil {
    36  		return fmt.Errorf("delete session: %w", err)
    37  	}
    38  
    39  	return nil
    40  }
    41  
    42  func (s *authService) CreateSession(ctx context.Context, code string) (*auth.Session, error) {
    43  	tokens, err := s.azureClient.Exchange(ctx, code)
    44  	if err != nil {
    45  		return nil, fmt.Errorf("exchange authorization code for tokens: %w", err)
    46  	}
    47  
    48  	rawIDToken, ok := tokens.Extra("id_token").(string)
    49  	if !ok {
    50  		return nil, fmt.Errorf("missing id_token")
    51  	}
    52  
    53  	// Parse and verify ID Token payload.
    54  	_, err = s.azureClient.Verify(ctx, rawIDToken)
    55  	if err != nil {
    56  		return nil, fmt.Errorf("verify ID token: %w", err)
    57  	}
    58  
    59  	sess := &auth.Session{
    60  		Token:       common.GenerateSecureToken(s.tokenLength),
    61  		Expires:     time.Now().Add(s.sessionLength),
    62  		AccessToken: tokens.AccessToken,
    63  	}
    64  
    65  	b, err := base64.RawStdEncoding.DecodeString(strings.Split(tokens.AccessToken, ".")[1])
    66  	if err != nil {
    67  		return nil, fmt.Errorf("decode access token: %w", err)
    68  	}
    69  
    70  	if err := json.Unmarshal(b, sess); err != nil {
    71  		return nil, fmt.Errorf("unmarshal access token: %w", err)
    72  	}
    73  
    74  	sess.IsAdmin, err = s.isUserInAdminGroup(sess.AccessToken)
    75  	if err != nil {
    76  		return nil, fmt.Errorf("check if user is in admin group: %w", err)
    77  	}
    78  
    79  	err = s.repo.SessionCreate(ctx, sess)
    80  	if err != nil {
    81  		return nil, fmt.Errorf("create session: %w", err)
    82  	}
    83  
    84  	return sess, nil
    85  }
    86  
    87  func (s *authService) isUserInAdminGroup(token string) (bool, error) {
    88  	var claims jwt.MapClaims
    89  
    90  	certificates, err := s.azureClient.FetchCertificates()
    91  	if err != nil {
    92  		return false, fmt.Errorf("fetch certificates: %w", err)
    93  	}
    94  
    95  	jwtValidator := auth.JWTValidator(certificates, s.azureClient.ClientID)
    96  
    97  	_, err = jwt.ParseWithClaims(token, &claims, jwtValidator)
    98  	if err != nil {
    99  		return false, fmt.Errorf("parse claims: %w", err)
   100  	}
   101  
   102  	if claims["groups"] == nil {
   103  		return false, nil
   104  	}
   105  
   106  	groups, ok := claims["groups"].([]interface{})
   107  	if !ok {
   108  		return false, nil
   109  	}
   110  
   111  	for _, group := range groups {
   112  		if grp, ok := group.(string); ok {
   113  			if grp == s.adminGroupID {
   114  				return true, nil
   115  			}
   116  		}
   117  	}
   118  
   119  	return false, nil
   120  }
   121  
   122  func (s *authService) GetLoginURL(state string) string {
   123  	return s.azureClient.AuthCodeURL(state)
   124  }
   125  
   126  func NewAuthService(repo *database.Repo, adminGroupID string, sessionLength time.Duration, tokenLength int, azureClient *auth.Azure) *authService {
   127  	return &authService{
   128  		azureClient:   azureClient,
   129  		tokenLength:   tokenLength,
   130  		sessionLength: sessionLength,
   131  		adminGroupID:  adminGroupID,
   132  		repo:          repo,
   133  	}
   134  }