github.com/pusher/oauth2_proxy@v3.2.0+incompatible/providers/github.go (about)

     1  package providers
     2  
     3  import (
     4  	"encoding/json"
     5  	"fmt"
     6  	"io/ioutil"
     7  	"log"
     8  	"net/http"
     9  	"net/url"
    10  	"path"
    11  	"strconv"
    12  	"strings"
    13  )
    14  
    15  // GitHubProvider represents an GitHub based Identity Provider
    16  type GitHubProvider struct {
    17  	*ProviderData
    18  	Org  string
    19  	Team string
    20  }
    21  
    22  // NewGitHubProvider initiates a new GitHubProvider
    23  func NewGitHubProvider(p *ProviderData) *GitHubProvider {
    24  	p.ProviderName = "GitHub"
    25  	if p.LoginURL == nil || p.LoginURL.String() == "" {
    26  		p.LoginURL = &url.URL{
    27  			Scheme: "https",
    28  			Host:   "github.com",
    29  			Path:   "/login/oauth/authorize",
    30  		}
    31  	}
    32  	if p.RedeemURL == nil || p.RedeemURL.String() == "" {
    33  		p.RedeemURL = &url.URL{
    34  			Scheme: "https",
    35  			Host:   "github.com",
    36  			Path:   "/login/oauth/access_token",
    37  		}
    38  	}
    39  	// ValidationURL is the API Base URL
    40  	if p.ValidateURL == nil || p.ValidateURL.String() == "" {
    41  		p.ValidateURL = &url.URL{
    42  			Scheme: "https",
    43  			Host:   "api.github.com",
    44  			Path:   "/",
    45  		}
    46  	}
    47  	if p.Scope == "" {
    48  		p.Scope = "user:email"
    49  	}
    50  	return &GitHubProvider{ProviderData: p}
    51  }
    52  
    53  // SetOrgTeam adds GitHub org reading parameters to the OAuth2 scope
    54  func (p *GitHubProvider) SetOrgTeam(org, team string) {
    55  	p.Org = org
    56  	p.Team = team
    57  	if org != "" || team != "" {
    58  		p.Scope += " read:org"
    59  	}
    60  }
    61  
    62  func (p *GitHubProvider) hasOrg(accessToken string) (bool, error) {
    63  	// https://developer.github.com/v3/orgs/#list-your-organizations
    64  
    65  	var orgs []struct {
    66  		Login string `json:"login"`
    67  	}
    68  
    69  	type orgsPage []struct {
    70  		Login string `json:"login"`
    71  	}
    72  
    73  	pn := 1
    74  	for {
    75  		params := url.Values{
    76  			"limit": {"200"},
    77  			"page":  {strconv.Itoa(pn)},
    78  		}
    79  
    80  		endpoint := &url.URL{
    81  			Scheme:   p.ValidateURL.Scheme,
    82  			Host:     p.ValidateURL.Host,
    83  			Path:     path.Join(p.ValidateURL.Path, "/user/orgs"),
    84  			RawQuery: params.Encode(),
    85  		}
    86  		req, _ := http.NewRequest("GET", endpoint.String(), nil)
    87  		req.Header.Set("Accept", "application/vnd.github.v3+json")
    88  		req.Header.Set("Authorization", fmt.Sprintf("token %s", accessToken))
    89  		resp, err := http.DefaultClient.Do(req)
    90  		if err != nil {
    91  			return false, err
    92  		}
    93  
    94  		body, err := ioutil.ReadAll(resp.Body)
    95  		resp.Body.Close()
    96  		if err != nil {
    97  			return false, err
    98  		}
    99  		if resp.StatusCode != 200 {
   100  			return false, fmt.Errorf(
   101  				"got %d from %q %s", resp.StatusCode, endpoint.String(), body)
   102  		}
   103  
   104  		var op orgsPage
   105  		if err := json.Unmarshal(body, &op); err != nil {
   106  			return false, err
   107  		}
   108  		if len(op) == 0 {
   109  			break
   110  		}
   111  
   112  		orgs = append(orgs, op...)
   113  		pn++
   114  	}
   115  
   116  	var presentOrgs []string
   117  	for _, org := range orgs {
   118  		if p.Org == org.Login {
   119  			log.Printf("Found Github Organization: %q", org.Login)
   120  			return true, nil
   121  		}
   122  		presentOrgs = append(presentOrgs, org.Login)
   123  	}
   124  
   125  	log.Printf("Missing Organization:%q in %v", p.Org, presentOrgs)
   126  	return false, nil
   127  }
   128  
   129  func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) {
   130  	// https://developer.github.com/v3/orgs/teams/#list-user-teams
   131  
   132  	var teams []struct {
   133  		Name string `json:"name"`
   134  		Slug string `json:"slug"`
   135  		Org  struct {
   136  			Login string `json:"login"`
   137  		} `json:"organization"`
   138  	}
   139  
   140  	params := url.Values{
   141  		"limit": {"200"},
   142  	}
   143  
   144  	endpoint := &url.URL{
   145  		Scheme:   p.ValidateURL.Scheme,
   146  		Host:     p.ValidateURL.Host,
   147  		Path:     path.Join(p.ValidateURL.Path, "/user/teams"),
   148  		RawQuery: params.Encode(),
   149  	}
   150  	req, _ := http.NewRequest("GET", endpoint.String(), nil)
   151  	req.Header.Set("Accept", "application/vnd.github.v3+json")
   152  	req.Header.Set("Authorization", fmt.Sprintf("token %s", accessToken))
   153  	resp, err := http.DefaultClient.Do(req)
   154  	if err != nil {
   155  		return false, err
   156  	}
   157  
   158  	body, err := ioutil.ReadAll(resp.Body)
   159  	resp.Body.Close()
   160  	if err != nil {
   161  		return false, err
   162  	}
   163  	if resp.StatusCode != 200 {
   164  		return false, fmt.Errorf(
   165  			"got %d from %q %s", resp.StatusCode, endpoint.String(), body)
   166  	}
   167  
   168  	if err := json.Unmarshal(body, &teams); err != nil {
   169  		return false, fmt.Errorf("%s unmarshaling %s", err, body)
   170  	}
   171  
   172  	var hasOrg bool
   173  	presentOrgs := make(map[string]bool)
   174  	var presentTeams []string
   175  	for _, team := range teams {
   176  		presentOrgs[team.Org.Login] = true
   177  		if p.Org == team.Org.Login {
   178  			hasOrg = true
   179  			ts := strings.Split(p.Team, ",")
   180  			for _, t := range ts {
   181  				if t == team.Slug {
   182  					log.Printf("Found Github Organization:%q Team:%q (Name:%q)", team.Org.Login, team.Slug, team.Name)
   183  					return true, nil
   184  				}
   185  			}
   186  			presentTeams = append(presentTeams, team.Slug)
   187  		}
   188  	}
   189  	if hasOrg {
   190  		log.Printf("Missing Team:%q from Org:%q in teams: %v", p.Team, p.Org, presentTeams)
   191  	} else {
   192  		var allOrgs []string
   193  		for org := range presentOrgs {
   194  			allOrgs = append(allOrgs, org)
   195  		}
   196  		log.Printf("Missing Organization:%q in %#v", p.Org, allOrgs)
   197  	}
   198  	return false, nil
   199  }
   200  
   201  // GetEmailAddress returns the Account email address
   202  func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) {
   203  
   204  	var emails []struct {
   205  		Email    string `json:"email"`
   206  		Primary  bool   `json:"primary"`
   207  		Verified bool   `json:"verified"`
   208  	}
   209  
   210  	// if we require an Org or Team, check that first
   211  	if p.Org != "" {
   212  		if p.Team != "" {
   213  			if ok, err := p.hasOrgAndTeam(s.AccessToken); err != nil || !ok {
   214  				return "", err
   215  			}
   216  		} else {
   217  			if ok, err := p.hasOrg(s.AccessToken); err != nil || !ok {
   218  				return "", err
   219  			}
   220  		}
   221  	}
   222  
   223  	endpoint := &url.URL{
   224  		Scheme: p.ValidateURL.Scheme,
   225  		Host:   p.ValidateURL.Host,
   226  		Path:   path.Join(p.ValidateURL.Path, "/user/emails"),
   227  	}
   228  	req, _ := http.NewRequest("GET", endpoint.String(), nil)
   229  	req.Header.Set("Authorization", fmt.Sprintf("token %s", s.AccessToken))
   230  	resp, err := http.DefaultClient.Do(req)
   231  	if err != nil {
   232  		return "", err
   233  	}
   234  	body, err := ioutil.ReadAll(resp.Body)
   235  	resp.Body.Close()
   236  	if err != nil {
   237  		return "", err
   238  	}
   239  
   240  	if resp.StatusCode != 200 {
   241  		return "", fmt.Errorf("got %d from %q %s",
   242  			resp.StatusCode, endpoint.String(), body)
   243  	}
   244  
   245  	log.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body)
   246  
   247  	if err := json.Unmarshal(body, &emails); err != nil {
   248  		return "", fmt.Errorf("%s unmarshaling %s", err, body)
   249  	}
   250  
   251  	for _, email := range emails {
   252  		if email.Primary && email.Verified {
   253  			return email.Email, nil
   254  		}
   255  	}
   256  
   257  	return "", nil
   258  }
   259  
   260  // GetUserName returns the Account user name
   261  func (p *GitHubProvider) GetUserName(s *SessionState) (string, error) {
   262  	var user struct {
   263  		Login string `json:"login"`
   264  		Email string `json:"email"`
   265  	}
   266  
   267  	endpoint := &url.URL{
   268  		Scheme: p.ValidateURL.Scheme,
   269  		Host:   p.ValidateURL.Host,
   270  		Path:   path.Join(p.ValidateURL.Path, "/user"),
   271  	}
   272  
   273  	req, err := http.NewRequest("GET", endpoint.String(), nil)
   274  	if err != nil {
   275  		return "", fmt.Errorf("could not create new GET request: %v", err)
   276  	}
   277  
   278  	req.Header.Set("Authorization", fmt.Sprintf("token %s", s.AccessToken))
   279  	resp, err := http.DefaultClient.Do(req)
   280  	if err != nil {
   281  		return "", err
   282  	}
   283  
   284  	body, err := ioutil.ReadAll(resp.Body)
   285  	defer resp.Body.Close()
   286  	if err != nil {
   287  		return "", err
   288  	}
   289  
   290  	if resp.StatusCode != 200 {
   291  		return "", fmt.Errorf("got %d from %q %s",
   292  			resp.StatusCode, endpoint.String(), body)
   293  	}
   294  
   295  	log.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body)
   296  
   297  	if err := json.Unmarshal(body, &user); err != nil {
   298  		return "", fmt.Errorf("%s unmarshaling %s", err, body)
   299  	}
   300  
   301  	return user.Login, nil
   302  }