github.com/pusher/oauth2_proxy@v3.2.0+incompatible/providers/google.go (about)

     1  package providers
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/base64"
     6  	"encoding/json"
     7  	"errors"
     8  	"fmt"
     9  	"io"
    10  	"io/ioutil"
    11  	"log"
    12  	"net/http"
    13  	"net/url"
    14  	"strings"
    15  	"time"
    16  
    17  	"golang.org/x/oauth2"
    18  	"golang.org/x/oauth2/google"
    19  	admin "google.golang.org/api/admin/directory/v1"
    20  	"google.golang.org/api/googleapi"
    21  )
    22  
    23  // GoogleProvider represents an Google based Identity Provider
    24  type GoogleProvider struct {
    25  	*ProviderData
    26  	RedeemRefreshURL *url.URL
    27  	// GroupValidator is a function that determines if the passed email is in
    28  	// the configured Google group.
    29  	GroupValidator func(string) bool
    30  }
    31  
    32  // NewGoogleProvider initiates a new GoogleProvider
    33  func NewGoogleProvider(p *ProviderData) *GoogleProvider {
    34  	p.ProviderName = "Google"
    35  	if p.LoginURL.String() == "" {
    36  		p.LoginURL = &url.URL{Scheme: "https",
    37  			Host: "accounts.google.com",
    38  			Path: "/o/oauth2/auth",
    39  			// to get a refresh token. see https://developers.google.com/identity/protocols/OAuth2WebServer#offline
    40  			RawQuery: "access_type=offline",
    41  		}
    42  	}
    43  	if p.RedeemURL.String() == "" {
    44  		p.RedeemURL = &url.URL{Scheme: "https",
    45  			Host: "www.googleapis.com",
    46  			Path: "/oauth2/v3/token"}
    47  	}
    48  	if p.ValidateURL.String() == "" {
    49  		p.ValidateURL = &url.URL{Scheme: "https",
    50  			Host: "www.googleapis.com",
    51  			Path: "/oauth2/v1/tokeninfo"}
    52  	}
    53  	if p.Scope == "" {
    54  		p.Scope = "profile email"
    55  	}
    56  
    57  	return &GoogleProvider{
    58  		ProviderData: p,
    59  		// Set a default GroupValidator to just always return valid (true), it will
    60  		// be overwritten if we configured a Google group restriction.
    61  		GroupValidator: func(email string) bool {
    62  			return true
    63  		},
    64  	}
    65  }
    66  
    67  func emailFromIDToken(idToken string) (string, error) {
    68  
    69  	// id_token is a base64 encode ID token payload
    70  	// https://developers.google.com/accounts/docs/OAuth2Login#obtainuserinfo
    71  	jwt := strings.Split(idToken, ".")
    72  	jwtData := strings.TrimSuffix(jwt[1], "=")
    73  	b, err := base64.RawURLEncoding.DecodeString(jwtData)
    74  	if err != nil {
    75  		return "", err
    76  	}
    77  
    78  	var email struct {
    79  		Email         string `json:"email"`
    80  		EmailVerified bool   `json:"email_verified"`
    81  	}
    82  	err = json.Unmarshal(b, &email)
    83  	if err != nil {
    84  		return "", err
    85  	}
    86  	if email.Email == "" {
    87  		return "", errors.New("missing email")
    88  	}
    89  	if !email.EmailVerified {
    90  		return "", fmt.Errorf("email %s not listed as verified", email.Email)
    91  	}
    92  	return email.Email, nil
    93  }
    94  
    95  // Redeem exchanges the OAuth2 authentication token for an ID token
    96  func (p *GoogleProvider) Redeem(redirectURL, code string) (s *SessionState, err error) {
    97  	if code == "" {
    98  		err = errors.New("missing code")
    99  		return
   100  	}
   101  
   102  	params := url.Values{}
   103  	params.Add("redirect_uri", redirectURL)
   104  	params.Add("client_id", p.ClientID)
   105  	params.Add("client_secret", p.ClientSecret)
   106  	params.Add("code", code)
   107  	params.Add("grant_type", "authorization_code")
   108  	var req *http.Request
   109  	req, err = http.NewRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
   110  	if err != nil {
   111  		return
   112  	}
   113  	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
   114  
   115  	resp, err := http.DefaultClient.Do(req)
   116  	if err != nil {
   117  		return
   118  	}
   119  	var body []byte
   120  	body, err = ioutil.ReadAll(resp.Body)
   121  	resp.Body.Close()
   122  	if err != nil {
   123  		return
   124  	}
   125  
   126  	if resp.StatusCode != 200 {
   127  		err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body)
   128  		return
   129  	}
   130  
   131  	var jsonResponse struct {
   132  		AccessToken  string `json:"access_token"`
   133  		RefreshToken string `json:"refresh_token"`
   134  		ExpiresIn    int64  `json:"expires_in"`
   135  		IDToken      string `json:"id_token"`
   136  	}
   137  	err = json.Unmarshal(body, &jsonResponse)
   138  	if err != nil {
   139  		return
   140  	}
   141  	var email string
   142  	email, err = emailFromIDToken(jsonResponse.IDToken)
   143  	if err != nil {
   144  		return
   145  	}
   146  	s = &SessionState{
   147  		AccessToken:  jsonResponse.AccessToken,
   148  		IDToken:      jsonResponse.IDToken,
   149  		ExpiresOn:    time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),
   150  		RefreshToken: jsonResponse.RefreshToken,
   151  		Email:        email,
   152  	}
   153  	return
   154  }
   155  
   156  // SetGroupRestriction configures the GoogleProvider to restrict access to the
   157  // specified group(s). AdminEmail has to be an administrative email on the domain that is
   158  // checked. CredentialsFile is the path to a json file containing a Google service
   159  // account credentials.
   160  func (p *GoogleProvider) SetGroupRestriction(groups []string, adminEmail string, credentialsReader io.Reader) {
   161  	adminService := getAdminService(adminEmail, credentialsReader)
   162  	p.GroupValidator = func(email string) bool {
   163  		return userInGroup(adminService, groups, email)
   164  	}
   165  }
   166  
   167  func getAdminService(adminEmail string, credentialsReader io.Reader) *admin.Service {
   168  	data, err := ioutil.ReadAll(credentialsReader)
   169  	if err != nil {
   170  		log.Fatal("can't read Google credentials file:", err)
   171  	}
   172  	conf, err := google.JWTConfigFromJSON(data, admin.AdminDirectoryUserReadonlyScope, admin.AdminDirectoryGroupReadonlyScope)
   173  	if err != nil {
   174  		log.Fatal("can't load Google credentials file:", err)
   175  	}
   176  	conf.Subject = adminEmail
   177  
   178  	client := conf.Client(oauth2.NoContext)
   179  	adminService, err := admin.New(client)
   180  	if err != nil {
   181  		log.Fatal(err)
   182  	}
   183  	return adminService
   184  }
   185  
   186  func userInGroup(service *admin.Service, groups []string, email string) bool {
   187  	user, err := fetchUser(service, email)
   188  	if err != nil {
   189  		log.Printf("error fetching user: %v", err)
   190  		return false
   191  	}
   192  	id := user.Id
   193  	custID := user.CustomerId
   194  
   195  	for _, group := range groups {
   196  		members, err := fetchGroupMembers(service, group)
   197  		if err != nil {
   198  			if err, ok := err.(*googleapi.Error); ok && err.Code == 404 {
   199  				log.Printf("error fetching members for group %s: group does not exist", group)
   200  			} else {
   201  				log.Printf("error fetching group members: %v", err)
   202  				return false
   203  			}
   204  		}
   205  
   206  		for _, member := range members {
   207  			switch member.Type {
   208  			case "CUSTOMER":
   209  				if member.Id == custID {
   210  					return true
   211  				}
   212  			case "USER":
   213  				if member.Id == id {
   214  					return true
   215  				}
   216  			}
   217  		}
   218  	}
   219  	return false
   220  }
   221  
   222  func fetchUser(service *admin.Service, email string) (*admin.User, error) {
   223  	user, err := service.Users.Get(email).Do()
   224  	return user, err
   225  }
   226  
   227  func fetchGroupMembers(service *admin.Service, group string) ([]*admin.Member, error) {
   228  	members := []*admin.Member{}
   229  	pageToken := ""
   230  	for {
   231  		req := service.Members.List(group)
   232  		if pageToken != "" {
   233  			req.PageToken(pageToken)
   234  		}
   235  		r, err := req.Do()
   236  		if err != nil {
   237  			return nil, err
   238  		}
   239  		for _, member := range r.Members {
   240  			members = append(members, member)
   241  		}
   242  		if r.NextPageToken == "" {
   243  			break
   244  		}
   245  		pageToken = r.NextPageToken
   246  	}
   247  	return members, nil
   248  }
   249  
   250  // ValidateGroup validates that the provided email exists in the configured Google
   251  // group(s).
   252  func (p *GoogleProvider) ValidateGroup(email string) bool {
   253  	return p.GroupValidator(email)
   254  }
   255  
   256  // RefreshSessionIfNeeded checks if the session has expired and uses the
   257  // RefreshToken to fetch a new ID token if required
   258  func (p *GoogleProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) {
   259  	if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
   260  		return false, nil
   261  	}
   262  
   263  	newToken, newIDToken, duration, err := p.redeemRefreshToken(s.RefreshToken)
   264  	if err != nil {
   265  		return false, err
   266  	}
   267  
   268  	// re-check that the user is in the proper google group(s)
   269  	if !p.ValidateGroup(s.Email) {
   270  		return false, fmt.Errorf("%s is no longer in the group(s)", s.Email)
   271  	}
   272  
   273  	origExpiration := s.ExpiresOn
   274  	s.AccessToken = newToken
   275  	s.IDToken = newIDToken
   276  	s.ExpiresOn = time.Now().Add(duration).Truncate(time.Second)
   277  	log.Printf("refreshed access token %s (expired on %s)", s, origExpiration)
   278  	return true, nil
   279  }
   280  
   281  func (p *GoogleProvider) redeemRefreshToken(refreshToken string) (token string, idToken string, expires time.Duration, err error) {
   282  	// https://developers.google.com/identity/protocols/OAuth2WebServer#refresh
   283  	params := url.Values{}
   284  	params.Add("client_id", p.ClientID)
   285  	params.Add("client_secret", p.ClientSecret)
   286  	params.Add("refresh_token", refreshToken)
   287  	params.Add("grant_type", "refresh_token")
   288  	var req *http.Request
   289  	req, err = http.NewRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
   290  	if err != nil {
   291  		return
   292  	}
   293  	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
   294  
   295  	resp, err := http.DefaultClient.Do(req)
   296  	if err != nil {
   297  		return
   298  	}
   299  	var body []byte
   300  	body, err = ioutil.ReadAll(resp.Body)
   301  	resp.Body.Close()
   302  	if err != nil {
   303  		return
   304  	}
   305  
   306  	if resp.StatusCode != 200 {
   307  		err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body)
   308  		return
   309  	}
   310  
   311  	var data struct {
   312  		AccessToken string `json:"access_token"`
   313  		ExpiresIn   int64  `json:"expires_in"`
   314  		IDToken     string `json:"id_token"`
   315  	}
   316  	err = json.Unmarshal(body, &data)
   317  	if err != nil {
   318  		return
   319  	}
   320  	token = data.AccessToken
   321  	idToken = data.IDToken
   322  	expires = time.Duration(data.ExpiresIn) * time.Second
   323  	return
   324  }