github.com/pusher/oauth2_proxy@v3.2.0+incompatible/providers/azure.go (about) 1 package providers 2 3 import ( 4 "errors" 5 "fmt" 6 "log" 7 "net/http" 8 "net/url" 9 10 "github.com/bitly/go-simplejson" 11 "github.com/pusher/oauth2_proxy/api" 12 ) 13 14 // AzureProvider represents an Azure based Identity Provider 15 type AzureProvider struct { 16 *ProviderData 17 Tenant string 18 } 19 20 // NewAzureProvider initiates a new AzureProvider 21 func NewAzureProvider(p *ProviderData) *AzureProvider { 22 p.ProviderName = "Azure" 23 24 if p.ProfileURL == nil || p.ProfileURL.String() == "" { 25 p.ProfileURL = &url.URL{ 26 Scheme: "https", 27 Host: "graph.windows.net", 28 Path: "/me", 29 RawQuery: "api-version=1.6", 30 } 31 } 32 if p.ProtectedResource == nil || p.ProtectedResource.String() == "" { 33 p.ProtectedResource = &url.URL{ 34 Scheme: "https", 35 Host: "graph.windows.net", 36 } 37 } 38 if p.Scope == "" { 39 p.Scope = "openid" 40 } 41 42 return &AzureProvider{ProviderData: p} 43 } 44 45 // Configure defaults the AzureProvider configuration options 46 func (p *AzureProvider) Configure(tenant string) { 47 p.Tenant = tenant 48 if tenant == "" { 49 p.Tenant = "common" 50 } 51 52 if p.LoginURL == nil || p.LoginURL.String() == "" { 53 p.LoginURL = &url.URL{ 54 Scheme: "https", 55 Host: "login.microsoftonline.com", 56 Path: "/" + p.Tenant + "/oauth2/authorize"} 57 } 58 if p.RedeemURL == nil || p.RedeemURL.String() == "" { 59 p.RedeemURL = &url.URL{ 60 Scheme: "https", 61 Host: "login.microsoftonline.com", 62 Path: "/" + p.Tenant + "/oauth2/token", 63 } 64 } 65 } 66 67 func getAzureHeader(accessToken string) http.Header { 68 header := make(http.Header) 69 header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) 70 return header 71 } 72 73 func getEmailFromJSON(json *simplejson.Json) (string, error) { 74 var email string 75 var err error 76 77 email, err = json.Get("mail").String() 78 79 if err != nil || email == "" { 80 otherMails, otherMailsErr := json.Get("otherMails").Array() 81 if len(otherMails) > 0 { 82 email = otherMails[0].(string) 83 } 84 err = otherMailsErr 85 } 86 87 return email, err 88 } 89 90 // GetEmailAddress returns the Account email address 91 func (p *AzureProvider) GetEmailAddress(s *SessionState) (string, error) { 92 var email string 93 var err error 94 95 if s.AccessToken == "" { 96 return "", errors.New("missing access token") 97 } 98 req, err := http.NewRequest("GET", p.ProfileURL.String(), nil) 99 if err != nil { 100 return "", err 101 } 102 req.Header = getAzureHeader(s.AccessToken) 103 104 json, err := api.Request(req) 105 106 if err != nil { 107 return "", err 108 } 109 110 email, err = getEmailFromJSON(json) 111 112 if err == nil && email != "" { 113 return email, err 114 } 115 116 email, err = json.Get("userPrincipalName").String() 117 118 if err != nil { 119 log.Printf("failed making request %s", err) 120 return "", err 121 } 122 123 if email == "" { 124 log.Printf("failed to get email address") 125 return "", err 126 } 127 128 return email, err 129 }