github.com/navikt/knorten@v0.0.0-20240419132333-1333f46ed8b6/pkg/api/auth/auth.go (about) 1 package auth 2 3 import ( 4 "context" 5 "crypto/x509" 6 "encoding/base64" 7 "encoding/json" 8 "errors" 9 "fmt" 10 "io" 11 "net/http" 12 "net/url" 13 "os" 14 "strings" 15 "time" 16 17 "github.com/coreos/go-oidc" 18 "github.com/golang-jwt/jwt/v4" 19 "github.com/sirupsen/logrus" 20 "golang.org/x/exp/rand" 21 "golang.org/x/oauth2" 22 "golang.org/x/oauth2/endpoints" 23 ) 24 25 type OauthConfig struct { 26 ClientID string 27 ClientSecret string 28 TenantID string 29 } 30 31 type Session struct { 32 Email string `json:"preferred_username"` 33 Name string `json:"name"` 34 AccessToken string 35 Token string 36 Expires time.Time 37 IsAdmin bool 38 } 39 40 type Azure struct { 41 oauth2.Config 42 43 clientID string 44 clientSecret string 45 tenantID string 46 dryRun bool 47 provider *oidc.Provider 48 log *logrus.Entry 49 } 50 51 type User struct { 52 Name string 53 Email string 54 Expires time.Time 55 } 56 57 type AzureGroupsWithIDResponse struct { 58 Groups []AzureGroupWithID `json:"value"` 59 } 60 61 type AzureGroupWithID struct { 62 DisplayName string `json:"displayName"` 63 ID string `json:"id"` 64 Mail string `json:"mail"` 65 } 66 67 type TokenResponse struct { 68 AccessToken string `json:"access_token"` 69 } 70 71 var ErrAzureTokenExpired = fmt.Errorf("token expired") 72 73 const ( 74 AzureUsersEndpoint = "https://graph.microsoft.com/v1.0/users" 75 AzureGroupsEndpoint = "https://graph.microsoft.com/v1.0/groups" 76 ) 77 78 func NewAzureClient(dryRun bool, clientID, clientSecret, tenantID string, log *logrus.Entry) (*Azure, error) { 79 if dryRun { 80 log.Infof("NOOP: Running in dry run mode") 81 return &Azure{ 82 dryRun: dryRun, 83 log: log, 84 }, nil 85 } 86 87 provider, err := oidc.NewProvider(context.Background(), fmt.Sprintf("https://login.microsoftonline.com/%v/v2.0", tenantID)) 88 if err != nil { 89 return nil, err 90 } 91 92 a := &Azure{ 93 clientID: clientID, 94 clientSecret: clientSecret, 95 tenantID: tenantID, 96 provider: provider, 97 dryRun: dryRun, 98 log: log, 99 } 100 101 a.setupOAuth2() 102 return a, nil 103 } 104 105 func (a *Azure) setupOAuth2() { 106 redirectURL := "https://knorten.knada.io/oauth2/callback" 107 if os.Getenv("GIN_MODE") != "release" { 108 redirectURL = "http://localhost:8080/oauth2/callback" 109 } 110 111 a.Config = oauth2.Config{ 112 ClientID: a.clientID, 113 ClientSecret: a.clientSecret, 114 Endpoint: a.provider.Endpoint(), 115 RedirectURL: redirectURL, 116 Scopes: []string{"openid", fmt.Sprintf("%s/.default", a.clientID)}, 117 } 118 } 119 120 func (a *Azure) KeyDiscoveryURL() string { 121 return fmt.Sprintf("https://login.microsoftonline.com/%s/discovery/v2.0/keys", a.tenantID) 122 } 123 124 func (a *Azure) Verify(ctx context.Context, rawIDToken string) (*oidc.IDToken, error) { 125 return a.provider.Verifier(&oidc.Config{ClientID: a.clientID}).Verify(ctx, rawIDToken) 126 } 127 128 func (a *Azure) FetchCertificates() (map[string]CertificateList, error) { 129 discoveryURL := a.KeyDiscoveryURL() 130 azureKeyDiscovery, err := DiscoverURL(discoveryURL) 131 if err != nil { 132 return nil, err 133 } 134 135 azureCertificates, err := azureKeyDiscovery.Map() 136 if err != nil { 137 return nil, err 138 } 139 return azureCertificates, nil 140 } 141 142 func (a *Azure) ValidateUser(certificates map[string]CertificateList, token string) (*User, error) { 143 var claims jwt.MapClaims 144 145 jwtValidator := JWTValidator(certificates, a.clientID) 146 147 azureToken, err := jwt.ParseWithClaims(token, &claims, jwtValidator) 148 if err != nil { 149 return nil, err 150 } 151 if !azureToken.Valid { 152 return nil, ErrAzureTokenExpired 153 } 154 155 return &User{ 156 Name: claims["name"].(string), 157 Email: strings.ToLower(claims["preferred_username"].(string)), 158 Expires: time.Unix(int64(claims["exp"].(float64)), 0), 159 }, nil 160 } 161 162 func (a *Azure) UserExistsInAzureAD(user string) error { 163 if a.dryRun { 164 fmt.Printf("NOOP: Would have checked if user %v exists in Azure AD\n", user) 165 return nil 166 } 167 168 type usersResponse struct { 169 Value []struct { 170 Email string `json:"userPrincipalName"` 171 } `json:"value"` 172 } 173 174 token, err := a.getBearerTokenForApplication() 175 if err != nil { 176 return err 177 } 178 179 r, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%v?$filter=startswith(userPrincipalName,'%v')", AzureUsersEndpoint, user), nil) 180 if err != nil { 181 return err 182 } 183 r.Header.Add("Authorization", fmt.Sprintf("Bearer %v", token)) 184 185 httpClient := &http.Client{ 186 Timeout: time.Second * 10, 187 } 188 189 res, err := httpClient.Do(r) 190 if err != nil { 191 return err 192 } 193 194 resBytes, err := io.ReadAll(res.Body) 195 if err != nil { 196 return err 197 } 198 199 var users usersResponse 200 if err := json.Unmarshal(resBytes, &users); err != nil { 201 return err 202 } 203 204 switch len(users.Value) { 205 case 0: 206 return fmt.Errorf("no user exists in aad with email %v", user) 207 case 1: 208 return nil 209 default: 210 return fmt.Errorf("multiple users exist in aad for email %v", user) 211 } 212 } 213 214 func (a *Azure) ConvertEmailsToIdents(emails []string) ([]string, error) { 215 var idents []string 216 for _, e := range emails { 217 ident, err := a.identForEmail(e) 218 if err != nil { 219 return nil, err 220 } 221 if ident != "" { 222 idents = append(idents, ident) 223 } 224 } 225 return idents, nil 226 } 227 228 func (a *Azure) identForEmail(email string) (string, error) { 229 if a.dryRun { 230 a.log.Infof("NOOP: Running in dry run mode") 231 return fmt.Sprintf("d%v", rand.Intn(10000)+100000), nil 232 } 233 234 type identResponse struct { 235 Ident string `json:"onPremisesSamAccountName"` 236 } 237 238 token, err := a.getBearerTokenForApplication() 239 if err != nil { 240 return "", err 241 } 242 243 r, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%v/%v?$select=onPremisesSamAccountName", AzureUsersEndpoint, email), nil) 244 if err != nil { 245 return "", err 246 } 247 r.Header.Add("Authorization", fmt.Sprintf("Bearer %v", token)) 248 249 httpClient := &http.Client{ 250 Timeout: time.Second * 10, 251 } 252 253 res, err := httpClient.Do(r) 254 if err != nil { 255 return "", err 256 } 257 258 resBytes, err := io.ReadAll(res.Body) 259 if err != nil { 260 return "", err 261 } 262 263 var identRes identResponse 264 if err := json.Unmarshal(resBytes, &identRes); err != nil { 265 return "", err 266 } 267 268 if identRes.Ident == "" { 269 a.log.Errorf("unable to get user ident for email %v", email) 270 } 271 272 return strings.ToLower(identRes.Ident), nil 273 } 274 275 func (a *Azure) getBearerTokenForApplication() (string, error) { 276 form := url.Values{} 277 form.Add("client_id", a.clientID) 278 form.Add("client_secret", a.clientSecret) 279 form.Add("scope", "https://graph.microsoft.com/.default") 280 form.Add("grant_type", "client_credentials") 281 282 req, err := http.NewRequest(http.MethodPost, endpoints.AzureAD(a.tenantID).TokenURL, strings.NewReader(form.Encode())) 283 if err != nil { 284 return "", err 285 } 286 287 httpClient := &http.Client{ 288 Timeout: time.Second * 10, 289 } 290 291 response, err := httpClient.Do(req) 292 if err != nil { 293 return "", err 294 } 295 296 var tokenResponse TokenResponse 297 if err := json.NewDecoder(response.Body).Decode(&tokenResponse); err != nil { 298 return "", err 299 } 300 301 return tokenResponse.AccessToken, nil 302 } 303 304 func (a *Azure) GetGroupID(groupMail string) (string, error) { 305 if a.dryRun { 306 a.log.Infof("NOOP: Running in dry run mode") 307 return "dummyID", nil 308 } 309 310 token, err := a.getBearerTokenForApplication() 311 if err != nil { 312 return "", err 313 } 314 315 params := url.Values{} 316 params.Add("$select", "id,displayName,mail") 317 params.Add("$filter", fmt.Sprintf("mail eq '%v'", groupMail)) 318 319 req, err := http.NewRequest(http.MethodGet, 320 AzureGroupsEndpoint+"?"+params.Encode(), 321 nil) 322 if err != nil { 323 return "", err 324 } 325 req.Header.Add("Authorization", fmt.Sprintf("Bearer %v", token)) 326 327 httpClient := &http.Client{ 328 Timeout: time.Second * 10, 329 } 330 331 response, err := httpClient.Do(req) 332 if err != nil { 333 return "", err 334 } 335 336 var groupsResponse AzureGroupsWithIDResponse 337 if err := json.NewDecoder(response.Body).Decode(&groupsResponse); err != nil { 338 return "", err 339 } 340 341 if len(groupsResponse.Groups) > 0 { 342 return groupsResponse.Groups[0].ID, nil 343 } else { 344 return "", errors.New("group not found by the mail") 345 } 346 } 347 348 type CertificateList []*x509.Certificate 349 350 type KeyDiscovery struct { 351 Keys []Key `json:"keys"` 352 } 353 354 type EncodedCertificate string 355 356 type Key struct { 357 Kid string `json:"kid"` 358 X5c []EncodedCertificate `json:"x5c"` 359 } 360 361 // Map transform a KeyDiscovery object into a dictionary with "kid" as key 362 // and lists of decoded X509 certificates as values. 363 // 364 // Returns an error if any certificate does not decode. 365 func (k *KeyDiscovery) Map() (result map[string]CertificateList, err error) { 366 result = make(map[string]CertificateList) 367 368 for _, key := range k.Keys { 369 certList := make(CertificateList, 0) 370 for _, encodedCertificate := range key.X5c { 371 certificate, err := encodedCertificate.Decode() 372 if err != nil { 373 return nil, err 374 } 375 certList = append(certList, certificate) 376 } 377 result[key.Kid] = certList 378 } 379 380 return 381 } 382 383 // Decode a base64 encoded certificate into a X509 structure. 384 func (c EncodedCertificate) Decode() (*x509.Certificate, error) { 385 stream := strings.NewReader(string(c)) 386 decoder := base64.NewDecoder(base64.StdEncoding, stream) 387 key, err := io.ReadAll(decoder) 388 if err != nil { 389 return nil, err 390 } 391 392 return x509.ParseCertificate(key) 393 } 394 395 func DiscoverURL(url string) (*KeyDiscovery, error) { 396 response, err := http.Get(url) 397 if err != nil { 398 return nil, err 399 } 400 401 return Discover(response.Body) 402 } 403 404 func Discover(reader io.Reader) (*KeyDiscovery, error) { 405 document, err := io.ReadAll(reader) 406 if err != nil { 407 return nil, err 408 } 409 410 keyDiscovery := &KeyDiscovery{} 411 err = json.Unmarshal(document, keyDiscovery) 412 413 return keyDiscovery, err 414 } 415 416 func JWTValidator(certificates map[string]CertificateList, audience string) jwt.Keyfunc { 417 return func(token *jwt.Token) (interface{}, error) { 418 var certificateList CertificateList 419 var kid string 420 var ok bool 421 422 if claims, ok := token.Claims.(*jwt.MapClaims); !ok { 423 return nil, fmt.Errorf("unable to retrieve claims from token") 424 } else { 425 if valid := claims.VerifyAudience(audience, true); !valid { 426 return nil, fmt.Errorf("the token is not valid for this application") 427 } 428 } 429 430 if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { 431 return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) 432 } 433 434 if kid, ok = token.Header["kid"].(string); !ok { 435 return nil, fmt.Errorf("field 'kid' is of invalid type %T, should be string", token.Header["kid"]) 436 } 437 438 if certificateList, ok = certificates[kid]; !ok { 439 return nil, fmt.Errorf("kid '%s' not found in certificate list", kid) 440 } 441 442 for _, certificate := range certificateList { 443 return certificate.PublicKey, nil 444 } 445 446 return nil, fmt.Errorf("no certificate candidates for kid '%s'", kid) 447 } 448 }