github.com/navikt/knorten@v0.0.0-20240419132333-1333f46ed8b6/pkg/api/auth/auth.go (about)

     1  package auth
     2  
     3  import (
     4  	"context"
     5  	"crypto/x509"
     6  	"encoding/base64"
     7  	"encoding/json"
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"net/http"
    12  	"net/url"
    13  	"os"
    14  	"strings"
    15  	"time"
    16  
    17  	"github.com/coreos/go-oidc"
    18  	"github.com/golang-jwt/jwt/v4"
    19  	"github.com/sirupsen/logrus"
    20  	"golang.org/x/exp/rand"
    21  	"golang.org/x/oauth2"
    22  	"golang.org/x/oauth2/endpoints"
    23  )
    24  
    25  type OauthConfig struct {
    26  	ClientID     string
    27  	ClientSecret string
    28  	TenantID     string
    29  }
    30  
    31  type Session struct {
    32  	Email       string `json:"preferred_username"`
    33  	Name        string `json:"name"`
    34  	AccessToken string
    35  	Token       string
    36  	Expires     time.Time
    37  	IsAdmin     bool
    38  }
    39  
    40  type Azure struct {
    41  	oauth2.Config
    42  
    43  	clientID     string
    44  	clientSecret string
    45  	tenantID     string
    46  	dryRun       bool
    47  	provider     *oidc.Provider
    48  	log          *logrus.Entry
    49  }
    50  
    51  type User struct {
    52  	Name    string
    53  	Email   string
    54  	Expires time.Time
    55  }
    56  
    57  type AzureGroupsWithIDResponse struct {
    58  	Groups []AzureGroupWithID `json:"value"`
    59  }
    60  
    61  type AzureGroupWithID struct {
    62  	DisplayName string `json:"displayName"`
    63  	ID          string `json:"id"`
    64  	Mail        string `json:"mail"`
    65  }
    66  
    67  type TokenResponse struct {
    68  	AccessToken string `json:"access_token"`
    69  }
    70  
    71  var ErrAzureTokenExpired = fmt.Errorf("token expired")
    72  
    73  const (
    74  	AzureUsersEndpoint  = "https://graph.microsoft.com/v1.0/users"
    75  	AzureGroupsEndpoint = "https://graph.microsoft.com/v1.0/groups"
    76  )
    77  
    78  func NewAzureClient(dryRun bool, clientID, clientSecret, tenantID string, log *logrus.Entry) (*Azure, error) {
    79  	if dryRun {
    80  		log.Infof("NOOP: Running in dry run mode")
    81  		return &Azure{
    82  			dryRun: dryRun,
    83  			log:    log,
    84  		}, nil
    85  	}
    86  
    87  	provider, err := oidc.NewProvider(context.Background(), fmt.Sprintf("https://login.microsoftonline.com/%v/v2.0", tenantID))
    88  	if err != nil {
    89  		return nil, err
    90  	}
    91  
    92  	a := &Azure{
    93  		clientID:     clientID,
    94  		clientSecret: clientSecret,
    95  		tenantID:     tenantID,
    96  		provider:     provider,
    97  		dryRun:       dryRun,
    98  		log:          log,
    99  	}
   100  
   101  	a.setupOAuth2()
   102  	return a, nil
   103  }
   104  
   105  func (a *Azure) setupOAuth2() {
   106  	redirectURL := "https://knorten.knada.io/oauth2/callback"
   107  	if os.Getenv("GIN_MODE") != "release" {
   108  		redirectURL = "http://localhost:8080/oauth2/callback"
   109  	}
   110  
   111  	a.Config = oauth2.Config{
   112  		ClientID:     a.clientID,
   113  		ClientSecret: a.clientSecret,
   114  		Endpoint:     a.provider.Endpoint(),
   115  		RedirectURL:  redirectURL,
   116  		Scopes:       []string{"openid", fmt.Sprintf("%s/.default", a.clientID)},
   117  	}
   118  }
   119  
   120  func (a *Azure) KeyDiscoveryURL() string {
   121  	return fmt.Sprintf("https://login.microsoftonline.com/%s/discovery/v2.0/keys", a.tenantID)
   122  }
   123  
   124  func (a *Azure) Verify(ctx context.Context, rawIDToken string) (*oidc.IDToken, error) {
   125  	return a.provider.Verifier(&oidc.Config{ClientID: a.clientID}).Verify(ctx, rawIDToken)
   126  }
   127  
   128  func (a *Azure) FetchCertificates() (map[string]CertificateList, error) {
   129  	discoveryURL := a.KeyDiscoveryURL()
   130  	azureKeyDiscovery, err := DiscoverURL(discoveryURL)
   131  	if err != nil {
   132  		return nil, err
   133  	}
   134  
   135  	azureCertificates, err := azureKeyDiscovery.Map()
   136  	if err != nil {
   137  		return nil, err
   138  	}
   139  	return azureCertificates, nil
   140  }
   141  
   142  func (a *Azure) ValidateUser(certificates map[string]CertificateList, token string) (*User, error) {
   143  	var claims jwt.MapClaims
   144  
   145  	jwtValidator := JWTValidator(certificates, a.clientID)
   146  
   147  	azureToken, err := jwt.ParseWithClaims(token, &claims, jwtValidator)
   148  	if err != nil {
   149  		return nil, err
   150  	}
   151  	if !azureToken.Valid {
   152  		return nil, ErrAzureTokenExpired
   153  	}
   154  
   155  	return &User{
   156  		Name:    claims["name"].(string),
   157  		Email:   strings.ToLower(claims["preferred_username"].(string)),
   158  		Expires: time.Unix(int64(claims["exp"].(float64)), 0),
   159  	}, nil
   160  }
   161  
   162  func (a *Azure) UserExistsInAzureAD(user string) error {
   163  	if a.dryRun {
   164  		fmt.Printf("NOOP: Would have checked if user %v exists in Azure AD\n", user)
   165  		return nil
   166  	}
   167  
   168  	type usersResponse struct {
   169  		Value []struct {
   170  			Email string `json:"userPrincipalName"`
   171  		} `json:"value"`
   172  	}
   173  
   174  	token, err := a.getBearerTokenForApplication()
   175  	if err != nil {
   176  		return err
   177  	}
   178  
   179  	r, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%v?$filter=startswith(userPrincipalName,'%v')", AzureUsersEndpoint, user), nil)
   180  	if err != nil {
   181  		return err
   182  	}
   183  	r.Header.Add("Authorization", fmt.Sprintf("Bearer %v", token))
   184  
   185  	httpClient := &http.Client{
   186  		Timeout: time.Second * 10,
   187  	}
   188  
   189  	res, err := httpClient.Do(r)
   190  	if err != nil {
   191  		return err
   192  	}
   193  
   194  	resBytes, err := io.ReadAll(res.Body)
   195  	if err != nil {
   196  		return err
   197  	}
   198  
   199  	var users usersResponse
   200  	if err := json.Unmarshal(resBytes, &users); err != nil {
   201  		return err
   202  	}
   203  
   204  	switch len(users.Value) {
   205  	case 0:
   206  		return fmt.Errorf("no user exists in aad with email %v", user)
   207  	case 1:
   208  		return nil
   209  	default:
   210  		return fmt.Errorf("multiple users exist in aad for email %v", user)
   211  	}
   212  }
   213  
   214  func (a *Azure) ConvertEmailsToIdents(emails []string) ([]string, error) {
   215  	var idents []string
   216  	for _, e := range emails {
   217  		ident, err := a.identForEmail(e)
   218  		if err != nil {
   219  			return nil, err
   220  		}
   221  		if ident != "" {
   222  			idents = append(idents, ident)
   223  		}
   224  	}
   225  	return idents, nil
   226  }
   227  
   228  func (a *Azure) identForEmail(email string) (string, error) {
   229  	if a.dryRun {
   230  		a.log.Infof("NOOP: Running in dry run mode")
   231  		return fmt.Sprintf("d%v", rand.Intn(10000)+100000), nil
   232  	}
   233  
   234  	type identResponse struct {
   235  		Ident string `json:"onPremisesSamAccountName"`
   236  	}
   237  
   238  	token, err := a.getBearerTokenForApplication()
   239  	if err != nil {
   240  		return "", err
   241  	}
   242  
   243  	r, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%v/%v?$select=onPremisesSamAccountName", AzureUsersEndpoint, email), nil)
   244  	if err != nil {
   245  		return "", err
   246  	}
   247  	r.Header.Add("Authorization", fmt.Sprintf("Bearer %v", token))
   248  
   249  	httpClient := &http.Client{
   250  		Timeout: time.Second * 10,
   251  	}
   252  
   253  	res, err := httpClient.Do(r)
   254  	if err != nil {
   255  		return "", err
   256  	}
   257  
   258  	resBytes, err := io.ReadAll(res.Body)
   259  	if err != nil {
   260  		return "", err
   261  	}
   262  
   263  	var identRes identResponse
   264  	if err := json.Unmarshal(resBytes, &identRes); err != nil {
   265  		return "", err
   266  	}
   267  
   268  	if identRes.Ident == "" {
   269  		a.log.Errorf("unable to get user ident for email %v", email)
   270  	}
   271  
   272  	return strings.ToLower(identRes.Ident), nil
   273  }
   274  
   275  func (a *Azure) getBearerTokenForApplication() (string, error) {
   276  	form := url.Values{}
   277  	form.Add("client_id", a.clientID)
   278  	form.Add("client_secret", a.clientSecret)
   279  	form.Add("scope", "https://graph.microsoft.com/.default")
   280  	form.Add("grant_type", "client_credentials")
   281  
   282  	req, err := http.NewRequest(http.MethodPost, endpoints.AzureAD(a.tenantID).TokenURL, strings.NewReader(form.Encode()))
   283  	if err != nil {
   284  		return "", err
   285  	}
   286  
   287  	httpClient := &http.Client{
   288  		Timeout: time.Second * 10,
   289  	}
   290  
   291  	response, err := httpClient.Do(req)
   292  	if err != nil {
   293  		return "", err
   294  	}
   295  
   296  	var tokenResponse TokenResponse
   297  	if err := json.NewDecoder(response.Body).Decode(&tokenResponse); err != nil {
   298  		return "", err
   299  	}
   300  
   301  	return tokenResponse.AccessToken, nil
   302  }
   303  
   304  func (a *Azure) GetGroupID(groupMail string) (string, error) {
   305  	if a.dryRun {
   306  		a.log.Infof("NOOP: Running in dry run mode")
   307  		return "dummyID", nil
   308  	}
   309  
   310  	token, err := a.getBearerTokenForApplication()
   311  	if err != nil {
   312  		return "", err
   313  	}
   314  
   315  	params := url.Values{}
   316  	params.Add("$select", "id,displayName,mail")
   317  	params.Add("$filter", fmt.Sprintf("mail eq '%v'", groupMail))
   318  
   319  	req, err := http.NewRequest(http.MethodGet,
   320  		AzureGroupsEndpoint+"?"+params.Encode(),
   321  		nil)
   322  	if err != nil {
   323  		return "", err
   324  	}
   325  	req.Header.Add("Authorization", fmt.Sprintf("Bearer %v", token))
   326  
   327  	httpClient := &http.Client{
   328  		Timeout: time.Second * 10,
   329  	}
   330  
   331  	response, err := httpClient.Do(req)
   332  	if err != nil {
   333  		return "", err
   334  	}
   335  
   336  	var groupsResponse AzureGroupsWithIDResponse
   337  	if err := json.NewDecoder(response.Body).Decode(&groupsResponse); err != nil {
   338  		return "", err
   339  	}
   340  
   341  	if len(groupsResponse.Groups) > 0 {
   342  		return groupsResponse.Groups[0].ID, nil
   343  	} else {
   344  		return "", errors.New("group not found by the mail")
   345  	}
   346  }
   347  
   348  type CertificateList []*x509.Certificate
   349  
   350  type KeyDiscovery struct {
   351  	Keys []Key `json:"keys"`
   352  }
   353  
   354  type EncodedCertificate string
   355  
   356  type Key struct {
   357  	Kid string               `json:"kid"`
   358  	X5c []EncodedCertificate `json:"x5c"`
   359  }
   360  
   361  // Map transform a KeyDiscovery object into a dictionary with "kid" as key
   362  // and lists of decoded X509 certificates as values.
   363  //
   364  // Returns an error if any certificate does not decode.
   365  func (k *KeyDiscovery) Map() (result map[string]CertificateList, err error) {
   366  	result = make(map[string]CertificateList)
   367  
   368  	for _, key := range k.Keys {
   369  		certList := make(CertificateList, 0)
   370  		for _, encodedCertificate := range key.X5c {
   371  			certificate, err := encodedCertificate.Decode()
   372  			if err != nil {
   373  				return nil, err
   374  			}
   375  			certList = append(certList, certificate)
   376  		}
   377  		result[key.Kid] = certList
   378  	}
   379  
   380  	return
   381  }
   382  
   383  // Decode a base64 encoded certificate into a X509 structure.
   384  func (c EncodedCertificate) Decode() (*x509.Certificate, error) {
   385  	stream := strings.NewReader(string(c))
   386  	decoder := base64.NewDecoder(base64.StdEncoding, stream)
   387  	key, err := io.ReadAll(decoder)
   388  	if err != nil {
   389  		return nil, err
   390  	}
   391  
   392  	return x509.ParseCertificate(key)
   393  }
   394  
   395  func DiscoverURL(url string) (*KeyDiscovery, error) {
   396  	response, err := http.Get(url)
   397  	if err != nil {
   398  		return nil, err
   399  	}
   400  
   401  	return Discover(response.Body)
   402  }
   403  
   404  func Discover(reader io.Reader) (*KeyDiscovery, error) {
   405  	document, err := io.ReadAll(reader)
   406  	if err != nil {
   407  		return nil, err
   408  	}
   409  
   410  	keyDiscovery := &KeyDiscovery{}
   411  	err = json.Unmarshal(document, keyDiscovery)
   412  
   413  	return keyDiscovery, err
   414  }
   415  
   416  func JWTValidator(certificates map[string]CertificateList, audience string) jwt.Keyfunc {
   417  	return func(token *jwt.Token) (interface{}, error) {
   418  		var certificateList CertificateList
   419  		var kid string
   420  		var ok bool
   421  
   422  		if claims, ok := token.Claims.(*jwt.MapClaims); !ok {
   423  			return nil, fmt.Errorf("unable to retrieve claims from token")
   424  		} else {
   425  			if valid := claims.VerifyAudience(audience, true); !valid {
   426  				return nil, fmt.Errorf("the token is not valid for this application")
   427  			}
   428  		}
   429  
   430  		if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
   431  			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
   432  		}
   433  
   434  		if kid, ok = token.Header["kid"].(string); !ok {
   435  			return nil, fmt.Errorf("field 'kid' is of invalid type %T, should be string", token.Header["kid"])
   436  		}
   437  
   438  		if certificateList, ok = certificates[kid]; !ok {
   439  			return nil, fmt.Errorf("kid '%s' not found in certificate list", kid)
   440  		}
   441  
   442  		for _, certificate := range certificateList {
   443  			return certificate.PublicKey, nil
   444  		}
   445  
   446  		return nil, fmt.Errorf("no certificate candidates for kid '%s'", kid)
   447  	}
   448  }