github.com/pusher/oauth2_proxy@v3.2.0+incompatible/providers/google.go (about) 1 package providers 2 3 import ( 4 "bytes" 5 "encoding/base64" 6 "encoding/json" 7 "errors" 8 "fmt" 9 "io" 10 "io/ioutil" 11 "log" 12 "net/http" 13 "net/url" 14 "strings" 15 "time" 16 17 "golang.org/x/oauth2" 18 "golang.org/x/oauth2/google" 19 admin "google.golang.org/api/admin/directory/v1" 20 "google.golang.org/api/googleapi" 21 ) 22 23 // GoogleProvider represents an Google based Identity Provider 24 type GoogleProvider struct { 25 *ProviderData 26 RedeemRefreshURL *url.URL 27 // GroupValidator is a function that determines if the passed email is in 28 // the configured Google group. 29 GroupValidator func(string) bool 30 } 31 32 // NewGoogleProvider initiates a new GoogleProvider 33 func NewGoogleProvider(p *ProviderData) *GoogleProvider { 34 p.ProviderName = "Google" 35 if p.LoginURL.String() == "" { 36 p.LoginURL = &url.URL{Scheme: "https", 37 Host: "accounts.google.com", 38 Path: "/o/oauth2/auth", 39 // to get a refresh token. see https://developers.google.com/identity/protocols/OAuth2WebServer#offline 40 RawQuery: "access_type=offline", 41 } 42 } 43 if p.RedeemURL.String() == "" { 44 p.RedeemURL = &url.URL{Scheme: "https", 45 Host: "www.googleapis.com", 46 Path: "/oauth2/v3/token"} 47 } 48 if p.ValidateURL.String() == "" { 49 p.ValidateURL = &url.URL{Scheme: "https", 50 Host: "www.googleapis.com", 51 Path: "/oauth2/v1/tokeninfo"} 52 } 53 if p.Scope == "" { 54 p.Scope = "profile email" 55 } 56 57 return &GoogleProvider{ 58 ProviderData: p, 59 // Set a default GroupValidator to just always return valid (true), it will 60 // be overwritten if we configured a Google group restriction. 61 GroupValidator: func(email string) bool { 62 return true 63 }, 64 } 65 } 66 67 func emailFromIDToken(idToken string) (string, error) { 68 69 // id_token is a base64 encode ID token payload 70 // https://developers.google.com/accounts/docs/OAuth2Login#obtainuserinfo 71 jwt := strings.Split(idToken, ".") 72 jwtData := strings.TrimSuffix(jwt[1], "=") 73 b, err := base64.RawURLEncoding.DecodeString(jwtData) 74 if err != nil { 75 return "", err 76 } 77 78 var email struct { 79 Email string `json:"email"` 80 EmailVerified bool `json:"email_verified"` 81 } 82 err = json.Unmarshal(b, &email) 83 if err != nil { 84 return "", err 85 } 86 if email.Email == "" { 87 return "", errors.New("missing email") 88 } 89 if !email.EmailVerified { 90 return "", fmt.Errorf("email %s not listed as verified", email.Email) 91 } 92 return email.Email, nil 93 } 94 95 // Redeem exchanges the OAuth2 authentication token for an ID token 96 func (p *GoogleProvider) Redeem(redirectURL, code string) (s *SessionState, err error) { 97 if code == "" { 98 err = errors.New("missing code") 99 return 100 } 101 102 params := url.Values{} 103 params.Add("redirect_uri", redirectURL) 104 params.Add("client_id", p.ClientID) 105 params.Add("client_secret", p.ClientSecret) 106 params.Add("code", code) 107 params.Add("grant_type", "authorization_code") 108 var req *http.Request 109 req, err = http.NewRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) 110 if err != nil { 111 return 112 } 113 req.Header.Set("Content-Type", "application/x-www-form-urlencoded") 114 115 resp, err := http.DefaultClient.Do(req) 116 if err != nil { 117 return 118 } 119 var body []byte 120 body, err = ioutil.ReadAll(resp.Body) 121 resp.Body.Close() 122 if err != nil { 123 return 124 } 125 126 if resp.StatusCode != 200 { 127 err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body) 128 return 129 } 130 131 var jsonResponse struct { 132 AccessToken string `json:"access_token"` 133 RefreshToken string `json:"refresh_token"` 134 ExpiresIn int64 `json:"expires_in"` 135 IDToken string `json:"id_token"` 136 } 137 err = json.Unmarshal(body, &jsonResponse) 138 if err != nil { 139 return 140 } 141 var email string 142 email, err = emailFromIDToken(jsonResponse.IDToken) 143 if err != nil { 144 return 145 } 146 s = &SessionState{ 147 AccessToken: jsonResponse.AccessToken, 148 IDToken: jsonResponse.IDToken, 149 ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second), 150 RefreshToken: jsonResponse.RefreshToken, 151 Email: email, 152 } 153 return 154 } 155 156 // SetGroupRestriction configures the GoogleProvider to restrict access to the 157 // specified group(s). AdminEmail has to be an administrative email on the domain that is 158 // checked. CredentialsFile is the path to a json file containing a Google service 159 // account credentials. 160 func (p *GoogleProvider) SetGroupRestriction(groups []string, adminEmail string, credentialsReader io.Reader) { 161 adminService := getAdminService(adminEmail, credentialsReader) 162 p.GroupValidator = func(email string) bool { 163 return userInGroup(adminService, groups, email) 164 } 165 } 166 167 func getAdminService(adminEmail string, credentialsReader io.Reader) *admin.Service { 168 data, err := ioutil.ReadAll(credentialsReader) 169 if err != nil { 170 log.Fatal("can't read Google credentials file:", err) 171 } 172 conf, err := google.JWTConfigFromJSON(data, admin.AdminDirectoryUserReadonlyScope, admin.AdminDirectoryGroupReadonlyScope) 173 if err != nil { 174 log.Fatal("can't load Google credentials file:", err) 175 } 176 conf.Subject = adminEmail 177 178 client := conf.Client(oauth2.NoContext) 179 adminService, err := admin.New(client) 180 if err != nil { 181 log.Fatal(err) 182 } 183 return adminService 184 } 185 186 func userInGroup(service *admin.Service, groups []string, email string) bool { 187 user, err := fetchUser(service, email) 188 if err != nil { 189 log.Printf("error fetching user: %v", err) 190 return false 191 } 192 id := user.Id 193 custID := user.CustomerId 194 195 for _, group := range groups { 196 members, err := fetchGroupMembers(service, group) 197 if err != nil { 198 if err, ok := err.(*googleapi.Error); ok && err.Code == 404 { 199 log.Printf("error fetching members for group %s: group does not exist", group) 200 } else { 201 log.Printf("error fetching group members: %v", err) 202 return false 203 } 204 } 205 206 for _, member := range members { 207 switch member.Type { 208 case "CUSTOMER": 209 if member.Id == custID { 210 return true 211 } 212 case "USER": 213 if member.Id == id { 214 return true 215 } 216 } 217 } 218 } 219 return false 220 } 221 222 func fetchUser(service *admin.Service, email string) (*admin.User, error) { 223 user, err := service.Users.Get(email).Do() 224 return user, err 225 } 226 227 func fetchGroupMembers(service *admin.Service, group string) ([]*admin.Member, error) { 228 members := []*admin.Member{} 229 pageToken := "" 230 for { 231 req := service.Members.List(group) 232 if pageToken != "" { 233 req.PageToken(pageToken) 234 } 235 r, err := req.Do() 236 if err != nil { 237 return nil, err 238 } 239 for _, member := range r.Members { 240 members = append(members, member) 241 } 242 if r.NextPageToken == "" { 243 break 244 } 245 pageToken = r.NextPageToken 246 } 247 return members, nil 248 } 249 250 // ValidateGroup validates that the provided email exists in the configured Google 251 // group(s). 252 func (p *GoogleProvider) ValidateGroup(email string) bool { 253 return p.GroupValidator(email) 254 } 255 256 // RefreshSessionIfNeeded checks if the session has expired and uses the 257 // RefreshToken to fetch a new ID token if required 258 func (p *GoogleProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) { 259 if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { 260 return false, nil 261 } 262 263 newToken, newIDToken, duration, err := p.redeemRefreshToken(s.RefreshToken) 264 if err != nil { 265 return false, err 266 } 267 268 // re-check that the user is in the proper google group(s) 269 if !p.ValidateGroup(s.Email) { 270 return false, fmt.Errorf("%s is no longer in the group(s)", s.Email) 271 } 272 273 origExpiration := s.ExpiresOn 274 s.AccessToken = newToken 275 s.IDToken = newIDToken 276 s.ExpiresOn = time.Now().Add(duration).Truncate(time.Second) 277 log.Printf("refreshed access token %s (expired on %s)", s, origExpiration) 278 return true, nil 279 } 280 281 func (p *GoogleProvider) redeemRefreshToken(refreshToken string) (token string, idToken string, expires time.Duration, err error) { 282 // https://developers.google.com/identity/protocols/OAuth2WebServer#refresh 283 params := url.Values{} 284 params.Add("client_id", p.ClientID) 285 params.Add("client_secret", p.ClientSecret) 286 params.Add("refresh_token", refreshToken) 287 params.Add("grant_type", "refresh_token") 288 var req *http.Request 289 req, err = http.NewRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) 290 if err != nil { 291 return 292 } 293 req.Header.Set("Content-Type", "application/x-www-form-urlencoded") 294 295 resp, err := http.DefaultClient.Do(req) 296 if err != nil { 297 return 298 } 299 var body []byte 300 body, err = ioutil.ReadAll(resp.Body) 301 resp.Body.Close() 302 if err != nil { 303 return 304 } 305 306 if resp.StatusCode != 200 { 307 err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body) 308 return 309 } 310 311 var data struct { 312 AccessToken string `json:"access_token"` 313 ExpiresIn int64 `json:"expires_in"` 314 IDToken string `json:"id_token"` 315 } 316 err = json.Unmarshal(body, &data) 317 if err != nil { 318 return 319 } 320 token = data.AccessToken 321 idToken = data.IDToken 322 expires = time.Duration(data.ExpiresIn) * time.Second 323 return 324 }