github.com/nais/knorten@v0.0.0-20240104110906-55926958e361/pkg/api/auth/auth.go (about) 1 package auth 2 3 import ( 4 "context" 5 "crypto/x509" 6 "encoding/base64" 7 "encoding/json" 8 "errors" 9 "fmt" 10 "io" 11 "net/http" 12 "net/url" 13 "os" 14 "strings" 15 "time" 16 17 "github.com/coreos/go-oidc" 18 "github.com/golang-jwt/jwt/v4" 19 "github.com/sirupsen/logrus" 20 "golang.org/x/exp/rand" 21 "golang.org/x/oauth2" 22 "golang.org/x/oauth2/endpoints" 23 ) 24 25 type OauthConfig struct { 26 ClientID string 27 ClientSecret string 28 TenantID string 29 } 30 31 type Session struct { 32 Email string `json:"preferred_username"` 33 Name string `json:"name"` 34 AccessToken string 35 Token string 36 Expires time.Time 37 IsAdmin bool 38 } 39 40 type Azure struct { 41 oauth2.Config 42 43 clientID string 44 clientSecret string 45 tenantID string 46 dryRun bool 47 provider *oidc.Provider 48 log *logrus.Entry 49 } 50 51 type User struct { 52 Name string 53 Email string 54 Expires time.Time 55 } 56 57 type AzureGroupsWithIDResponse struct { 58 Groups []AzureGroupWithID `json:"value"` 59 } 60 61 type AzureGroupWithID struct { 62 DisplayName string `json:"displayName"` 63 ID string `json:"id"` 64 Mail string `json:"mail"` 65 } 66 67 type TokenResponse struct { 68 AccessToken string `json:"access_token"` 69 } 70 71 var ErrAzureTokenExpired = fmt.Errorf("token expired") 72 73 const ( 74 AzureUsersEndpoint = "https://graph.microsoft.com/v1.0/users" 75 AzureGroupsEndpoint = "https://graph.microsoft.com/v1.0/groups" 76 ) 77 78 func NewAzureClient(dryRun bool, clientID, clientSecret, tenantID string, log *logrus.Entry) (*Azure, error) { 79 if dryRun { 80 log.Infof("NOOP: Running in dry run mode") 81 return &Azure{ 82 dryRun: dryRun, 83 log: log, 84 }, nil 85 } 86 87 provider, err := oidc.NewProvider(context.Background(), fmt.Sprintf("https://login.microsoftonline.com/%v/v2.0", tenantID)) 88 if err != nil { 89 return nil, err 90 } 91 92 a := &Azure{ 93 clientID: clientID, 94 clientSecret: clientSecret, 95 tenantID: tenantID, 96 provider: provider, 97 dryRun: dryRun, 98 log: log, 99 } 100 101 a.setupOAuth2() 102 return a, nil 103 } 104 105 func (a *Azure) setupOAuth2() { 106 redirectURL := "https://knorten.knada.io/oauth2/callback" 107 if os.Getenv("GIN_MODE") != "release" { 108 redirectURL = "http://localhost:8080/oauth2/callback" 109 } 110 111 a.Config = oauth2.Config{ 112 ClientID: a.clientID, 113 ClientSecret: a.clientSecret, 114 Endpoint: a.provider.Endpoint(), 115 RedirectURL: redirectURL, 116 Scopes: []string{"openid", fmt.Sprintf("%s/.default", a.clientID)}, 117 } 118 } 119 120 func (a *Azure) KeyDiscoveryURL() string { 121 return fmt.Sprintf("https://login.microsoftonline.com/%s/discovery/v2.0/keys", a.tenantID) 122 } 123 124 func (a *Azure) Verify(ctx context.Context, rawIDToken string) (*oidc.IDToken, error) { 125 return a.provider.Verifier(&oidc.Config{ClientID: a.clientID}).Verify(ctx, rawIDToken) 126 } 127 128 func (a *Azure) FetchCertificates() (map[string]CertificateList, error) { 129 discoveryURL := a.KeyDiscoveryURL() 130 azureKeyDiscovery, err := DiscoverURL(discoveryURL) 131 if err != nil { 132 return nil, err 133 } 134 135 azureCertificates, err := azureKeyDiscovery.Map() 136 if err != nil { 137 return nil, err 138 } 139 return azureCertificates, nil 140 } 141 142 func (a *Azure) ValidateUser(certificates map[string]CertificateList, token string) (*User, error) { 143 var claims jwt.MapClaims 144 145 jwtValidator := JWTValidator(certificates, a.clientID) 146 147 azureToken, err := jwt.ParseWithClaims(token, &claims, jwtValidator) 148 if err != nil { 149 return nil, err 150 } 151 if !azureToken.Valid { 152 return nil, ErrAzureTokenExpired 153 } 154 155 return &User{ 156 Name: claims["name"].(string), 157 Email: strings.ToLower(claims["preferred_username"].(string)), 158 Expires: time.Unix(int64(claims["exp"].(float64)), 0), 159 }, nil 160 } 161 162 func (a *Azure) UserExistsInAzureAD(user string) error { 163 if a.dryRun { 164 fmt.Printf("NOOP: Would have checked if user %v exists in Azure AD\n", user) 165 return nil 166 } 167 168 type usersResponse struct { 169 Value []struct { 170 Email string `json:"userPrincipalName"` 171 } `json:"value"` 172 } 173 174 token, err := a.getBearerTokenForApplication() 175 if err != nil { 176 return err 177 } 178 179 r, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%v?$filter=startswith(userPrincipalName,'%v')", AzureUsersEndpoint, user), nil) 180 if err != nil { 181 return err 182 } 183 r.Header.Add("Authorization", fmt.Sprintf("Bearer %v", token)) 184 185 httpClient := &http.Client{ 186 Timeout: time.Second * 10, 187 } 188 189 res, err := httpClient.Do(r) 190 if err != nil { 191 return err 192 } 193 194 resBytes, err := io.ReadAll(res.Body) 195 if err != nil { 196 return err 197 } 198 199 var users usersResponse 200 if err := json.Unmarshal(resBytes, &users); err != nil { 201 return err 202 } 203 204 switch len(users.Value) { 205 case 0: 206 return fmt.Errorf("no user exists in aad with email %v", user) 207 case 1: 208 return nil 209 default: 210 return fmt.Errorf("multiple users exist in aad for email %v", user) 211 } 212 } 213 214 func (a *Azure) ConvertEmailsToIdents(emails []string) ([]string, error) { 215 var idents []string 216 for _, e := range emails { 217 ident, err := a.identForEmail(e) 218 if err != nil { 219 return nil, err 220 } 221 idents = append(idents, ident) 222 } 223 return idents, nil 224 } 225 226 func (a *Azure) identForEmail(email string) (string, error) { 227 if a.dryRun { 228 a.log.Infof("NOOP: Running in dry run mode") 229 return fmt.Sprintf("d%v", rand.Intn(10000)+100000), nil 230 } 231 232 type identResponse struct { 233 Ident string `json:"onPremisesSamAccountName"` 234 } 235 236 token, err := a.getBearerTokenForApplication() 237 if err != nil { 238 return "", err 239 } 240 241 r, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%v/%v?$select=onPremisesSamAccountName", AzureUsersEndpoint, email), nil) 242 if err != nil { 243 return "", err 244 } 245 r.Header.Add("Authorization", fmt.Sprintf("Bearer %v", token)) 246 247 httpClient := &http.Client{ 248 Timeout: time.Second * 10, 249 } 250 251 res, err := httpClient.Do(r) 252 if err != nil { 253 return "", err 254 } 255 256 resBytes, err := io.ReadAll(res.Body) 257 if err != nil { 258 return "", err 259 } 260 261 var identRes identResponse 262 if err := json.Unmarshal(resBytes, &identRes); err != nil { 263 return "", err 264 } 265 266 if identRes.Ident == "" { 267 return "", fmt.Errorf("unable to get user ident for email %v", email) 268 } 269 270 return strings.ToLower(identRes.Ident), nil 271 } 272 273 func (a *Azure) getBearerTokenForApplication() (string, error) { 274 form := url.Values{} 275 form.Add("client_id", a.clientID) 276 form.Add("client_secret", a.clientSecret) 277 form.Add("scope", "https://graph.microsoft.com/.default") 278 form.Add("grant_type", "client_credentials") 279 280 req, err := http.NewRequest(http.MethodPost, endpoints.AzureAD(a.tenantID).TokenURL, strings.NewReader(form.Encode())) 281 if err != nil { 282 return "", err 283 } 284 285 httpClient := &http.Client{ 286 Timeout: time.Second * 10, 287 } 288 289 response, err := httpClient.Do(req) 290 if err != nil { 291 return "", err 292 } 293 294 var tokenResponse TokenResponse 295 if err := json.NewDecoder(response.Body).Decode(&tokenResponse); err != nil { 296 return "", err 297 } 298 299 return tokenResponse.AccessToken, nil 300 } 301 302 func (a *Azure) GetGroupID(groupMail string) (string, error) { 303 if a.dryRun { 304 a.log.Infof("NOOP: Running in dry run mode") 305 return "dummyID", nil 306 } 307 308 token, err := a.getBearerTokenForApplication() 309 if err != nil { 310 return "", err 311 } 312 313 params := url.Values{} 314 params.Add("$select", "id,displayName,mail") 315 params.Add("$filter", fmt.Sprintf("mail eq '%v'", groupMail)) 316 317 req, err := http.NewRequest(http.MethodGet, 318 AzureGroupsEndpoint+"?"+params.Encode(), 319 nil) 320 if err != nil { 321 return "", err 322 } 323 req.Header.Add("Authorization", fmt.Sprintf("Bearer %v", token)) 324 325 httpClient := &http.Client{ 326 Timeout: time.Second * 10, 327 } 328 329 response, err := httpClient.Do(req) 330 if err != nil { 331 return "", err 332 } 333 334 var groupsResponse AzureGroupsWithIDResponse 335 if err := json.NewDecoder(response.Body).Decode(&groupsResponse); err != nil { 336 return "", err 337 } 338 339 if len(groupsResponse.Groups) > 0 { 340 return groupsResponse.Groups[0].ID, nil 341 } else { 342 return "", errors.New("group not found by the mail") 343 } 344 } 345 346 type CertificateList []*x509.Certificate 347 348 type KeyDiscovery struct { 349 Keys []Key `json:"keys"` 350 } 351 352 type EncodedCertificate string 353 354 type Key struct { 355 Kid string `json:"kid"` 356 X5c []EncodedCertificate `json:"x5c"` 357 } 358 359 // Map transform a KeyDiscovery object into a dictionary with "kid" as key 360 // and lists of decoded X509 certificates as values. 361 // 362 // Returns an error if any certificate does not decode. 363 func (k *KeyDiscovery) Map() (result map[string]CertificateList, err error) { 364 result = make(map[string]CertificateList) 365 366 for _, key := range k.Keys { 367 certList := make(CertificateList, 0) 368 for _, encodedCertificate := range key.X5c { 369 certificate, err := encodedCertificate.Decode() 370 if err != nil { 371 return nil, err 372 } 373 certList = append(certList, certificate) 374 } 375 result[key.Kid] = certList 376 } 377 378 return 379 } 380 381 // Decode a base64 encoded certificate into a X509 structure. 382 func (c EncodedCertificate) Decode() (*x509.Certificate, error) { 383 stream := strings.NewReader(string(c)) 384 decoder := base64.NewDecoder(base64.StdEncoding, stream) 385 key, err := io.ReadAll(decoder) 386 if err != nil { 387 return nil, err 388 } 389 390 return x509.ParseCertificate(key) 391 } 392 393 func DiscoverURL(url string) (*KeyDiscovery, error) { 394 response, err := http.Get(url) 395 if err != nil { 396 return nil, err 397 } 398 399 return Discover(response.Body) 400 } 401 402 func Discover(reader io.Reader) (*KeyDiscovery, error) { 403 document, err := io.ReadAll(reader) 404 if err != nil { 405 return nil, err 406 } 407 408 keyDiscovery := &KeyDiscovery{} 409 err = json.Unmarshal(document, keyDiscovery) 410 411 return keyDiscovery, err 412 } 413 414 func JWTValidator(certificates map[string]CertificateList, audience string) jwt.Keyfunc { 415 return func(token *jwt.Token) (interface{}, error) { 416 var certificateList CertificateList 417 var kid string 418 var ok bool 419 420 if claims, ok := token.Claims.(*jwt.MapClaims); !ok { 421 return nil, fmt.Errorf("unable to retrieve claims from token") 422 } else { 423 if valid := claims.VerifyAudience(audience, true); !valid { 424 return nil, fmt.Errorf("the token is not valid for this application") 425 } 426 } 427 428 if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { 429 return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) 430 } 431 432 if kid, ok = token.Header["kid"].(string); !ok { 433 return nil, fmt.Errorf("field 'kid' is of invalid type %T, should be string", token.Header["kid"]) 434 } 435 436 if certificateList, ok = certificates[kid]; !ok { 437 return nil, fmt.Errorf("kid '%s' not found in certificate list", kid) 438 } 439 440 for _, certificate := range certificateList { 441 return certificate.PublicKey, nil 442 } 443 444 return nil, fmt.Errorf("no certificate candidates for kid '%s'", kid) 445 } 446 }