github.com/pusher/oauth2_proxy@v3.2.0+incompatible/providers/azure.go (about)

     1  package providers
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"log"
     7  	"net/http"
     8  	"net/url"
     9  
    10  	"github.com/bitly/go-simplejson"
    11  	"github.com/pusher/oauth2_proxy/api"
    12  )
    13  
    14  // AzureProvider represents an Azure based Identity Provider
    15  type AzureProvider struct {
    16  	*ProviderData
    17  	Tenant string
    18  }
    19  
    20  // NewAzureProvider initiates a new AzureProvider
    21  func NewAzureProvider(p *ProviderData) *AzureProvider {
    22  	p.ProviderName = "Azure"
    23  
    24  	if p.ProfileURL == nil || p.ProfileURL.String() == "" {
    25  		p.ProfileURL = &url.URL{
    26  			Scheme:   "https",
    27  			Host:     "graph.windows.net",
    28  			Path:     "/me",
    29  			RawQuery: "api-version=1.6",
    30  		}
    31  	}
    32  	if p.ProtectedResource == nil || p.ProtectedResource.String() == "" {
    33  		p.ProtectedResource = &url.URL{
    34  			Scheme: "https",
    35  			Host:   "graph.windows.net",
    36  		}
    37  	}
    38  	if p.Scope == "" {
    39  		p.Scope = "openid"
    40  	}
    41  
    42  	return &AzureProvider{ProviderData: p}
    43  }
    44  
    45  // Configure defaults the AzureProvider configuration options
    46  func (p *AzureProvider) Configure(tenant string) {
    47  	p.Tenant = tenant
    48  	if tenant == "" {
    49  		p.Tenant = "common"
    50  	}
    51  
    52  	if p.LoginURL == nil || p.LoginURL.String() == "" {
    53  		p.LoginURL = &url.URL{
    54  			Scheme: "https",
    55  			Host:   "login.microsoftonline.com",
    56  			Path:   "/" + p.Tenant + "/oauth2/authorize"}
    57  	}
    58  	if p.RedeemURL == nil || p.RedeemURL.String() == "" {
    59  		p.RedeemURL = &url.URL{
    60  			Scheme: "https",
    61  			Host:   "login.microsoftonline.com",
    62  			Path:   "/" + p.Tenant + "/oauth2/token",
    63  		}
    64  	}
    65  }
    66  
    67  func getAzureHeader(accessToken string) http.Header {
    68  	header := make(http.Header)
    69  	header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken))
    70  	return header
    71  }
    72  
    73  func getEmailFromJSON(json *simplejson.Json) (string, error) {
    74  	var email string
    75  	var err error
    76  
    77  	email, err = json.Get("mail").String()
    78  
    79  	if err != nil || email == "" {
    80  		otherMails, otherMailsErr := json.Get("otherMails").Array()
    81  		if len(otherMails) > 0 {
    82  			email = otherMails[0].(string)
    83  		}
    84  		err = otherMailsErr
    85  	}
    86  
    87  	return email, err
    88  }
    89  
    90  // GetEmailAddress returns the Account email address
    91  func (p *AzureProvider) GetEmailAddress(s *SessionState) (string, error) {
    92  	var email string
    93  	var err error
    94  
    95  	if s.AccessToken == "" {
    96  		return "", errors.New("missing access token")
    97  	}
    98  	req, err := http.NewRequest("GET", p.ProfileURL.String(), nil)
    99  	if err != nil {
   100  		return "", err
   101  	}
   102  	req.Header = getAzureHeader(s.AccessToken)
   103  
   104  	json, err := api.Request(req)
   105  
   106  	if err != nil {
   107  		return "", err
   108  	}
   109  
   110  	email, err = getEmailFromJSON(json)
   111  
   112  	if err == nil && email != "" {
   113  		return email, err
   114  	}
   115  
   116  	email, err = json.Get("userPrincipalName").String()
   117  
   118  	if err != nil {
   119  		log.Printf("failed making request %s", err)
   120  		return "", err
   121  	}
   122  
   123  	if email == "" {
   124  		log.Printf("failed to get email address")
   125  		return "", err
   126  	}
   127  
   128  	return email, err
   129  }