github.com/greenpau/go-authcrunch@v1.1.4/pkg/idp/oauth/authenticate.go (about)

     1  // Copyright 2022 Paul Greenberg greenpau@outlook.com
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package oauth
    16  
    17  import (
    18  	"encoding/json"
    19  	"io/ioutil"
    20  	"net/http"
    21  	"net/url"
    22  	"path"
    23  	"strconv"
    24  	"strings"
    25  	"time"
    26  
    27  	"github.com/greenpau/go-authcrunch/pkg/errors"
    28  	"github.com/greenpau/go-authcrunch/pkg/requests"
    29  	"github.com/greenpau/go-authcrunch/pkg/util"
    30  
    31  	"github.com/google/uuid"
    32  	"go.uber.org/zap"
    33  )
    34  
    35  // Authenticate performs authentication.
    36  func (b *IdentityProvider) Authenticate(r *requests.Request) error {
    37  	reqPath := r.Upstream.BaseURL + path.Join(r.Upstream.BasePath, r.Upstream.Method, r.Upstream.Realm)
    38  	r.Response.Code = http.StatusBadRequest
    39  
    40  	var accessTokenExists, idTokenExists, codeExists, stateExists, errorExists, loginHintExists, additionalScopesExists bool
    41  	var reqParamsAccessToken, reqParamsIDToken, reqParamsState, reqParamsCode, reqParamsError, reqParamsLoginHint, additionalScopes string
    42  	reqParams := r.Upstream.Request.URL.Query()
    43  	if _, exists := reqParams["access_token"]; exists {
    44  		accessTokenExists = true
    45  		reqParamsAccessToken = reqParams["access_token"][0]
    46  	}
    47  	if _, exists := reqParams["id_token"]; exists {
    48  		idTokenExists = true
    49  		reqParamsIDToken = reqParams["id_token"][0]
    50  	}
    51  	if _, exists := reqParams["code"]; exists {
    52  		codeExists = true
    53  		reqParamsCode = reqParams["code"][0]
    54  	}
    55  	if _, exists := reqParams["state"]; exists {
    56  		stateExists = true
    57  		reqParamsState = reqParams["state"][0]
    58  	}
    59  	if _, exists := reqParams["error"]; exists {
    60  		errorExists = true
    61  		reqParamsError = reqParams["error"][0]
    62  	}
    63  	if _, exists := reqParams["login_hint"]; exists {
    64  		loginHintExists = true
    65  		reqParamsLoginHint = reqParams["login_hint"][0]
    66  	}
    67  	if _, exists := reqParams["additional_scopes"]; exists {
    68  		additionalScopesExists = true
    69  		additionalScopes = reqParams["additional_scopes"][0]
    70  	}
    71  
    72  	if stateExists || errorExists || codeExists || accessTokenExists {
    73  		b.logger.Debug(
    74  			"received OAuth 2.0 response",
    75  			zap.String("session_id", r.Upstream.SessionID),
    76  			zap.String("request_id", r.ID),
    77  			zap.Any("params", reqParams),
    78  		)
    79  		if errorExists {
    80  			if v, exists := reqParams["error_description"]; exists {
    81  				return errors.ErrIdentityProviderOauthAuthorizationFailedDetailed.WithArgs(reqParamsError, v[0])
    82  			}
    83  			return errors.ErrIdentityProviderOauthAuthorizationFailed.WithArgs(reqParamsError)
    84  		}
    85  		switch {
    86  		case codeExists && stateExists:
    87  			// Received Authorization Code
    88  			if b.state.exists(reqParamsState) {
    89  				b.state.addCode(reqParamsState, reqParamsCode)
    90  			} else {
    91  				return errors.ErrIdentityProviderOauthAuthorizationStateNotFound
    92  			}
    93  			b.logger.Debug(
    94  				"received OAuth 2.0 code and state from the authorization server",
    95  				zap.String("session_id", r.Upstream.SessionID),
    96  				zap.String("request_id", r.ID),
    97  				zap.String("state", reqParamsState),
    98  				zap.String("code", reqParamsCode),
    99  			)
   100  
   101  			reqRedirectURI := reqPath + "/authorization-code-callback"
   102  			var accessToken map[string]interface{}
   103  			var err error
   104  			switch b.config.Driver {
   105  			case "facebook":
   106  				accessToken, err = b.fetchFacebookAccessToken(reqRedirectURI, reqParamsState, reqParamsCode)
   107  			default:
   108  				accessToken, err = b.fetchAccessToken(reqRedirectURI, reqParamsState, reqParamsCode)
   109  			}
   110  			if err != nil {
   111  				b.logger.Debug(
   112  					"failed fetching OAuth 2.0 access token from the authorization server",
   113  					zap.String("session_id", r.Upstream.SessionID),
   114  					zap.String("request_id", r.ID),
   115  					zap.Error(err),
   116  				)
   117  				return errors.ErrIdentityProviderOauthFetchAccessTokenFailed.WithArgs(err)
   118  			}
   119  			b.logger.Debug(
   120  				"received OAuth 2.0 authorization server access token",
   121  				zap.String("request_id", r.ID),
   122  				zap.Any("token", accessToken),
   123  			)
   124  
   125  			var m map[string]interface{}
   126  
   127  			switch b.config.Driver {
   128  			case "github", "gitlab", "facebook", "discord", "linkedin":
   129  				m, err = b.fetchClaims(accessToken)
   130  				if err != nil {
   131  					return errors.ErrIdentityProviderOauthFetchClaimsFailed.WithArgs(err)
   132  				}
   133  			default:
   134  				m, err = b.validateAccessToken(reqParamsState, accessToken)
   135  				if err != nil {
   136  					return errors.ErrIdentityProviderOauthValidateAccessTokenFailed.WithArgs(err)
   137  				}
   138  			}
   139  
   140  			// Fetch user info.
   141  			if err := b.fetchUserInfo(accessToken, m); err != nil {
   142  				b.logger.Debug(
   143  					"failed fetching user info",
   144  					zap.String("request_id", r.ID),
   145  					zap.Error(err),
   146  				)
   147  			}
   148  
   149  			// Fetch subsequent user info, e.g. user groups.
   150  			if err := b.fetchUserGroups(accessToken, m); err != nil {
   151  				b.logger.Debug(
   152  					"failed fetching user groups",
   153  					zap.String("request_id", r.ID),
   154  					zap.Error(err),
   155  				)
   156  			}
   157  
   158  			if b.config.IdentityTokenCookieEnabled {
   159  				if v, exists := accessToken["id_token"]; exists {
   160  					r.Response.IdentityTokenCookie.Enabled = true
   161  					r.Response.IdentityTokenCookie.Name = b.config.IdentityTokenCookieName
   162  					r.Response.IdentityTokenCookie.Payload = v.(string)
   163  				}
   164  			}
   165  
   166  			r.Response.Payload = m
   167  			r.Response.Code = http.StatusOK
   168  			b.logger.Debug(
   169  				"decoded claims from OAuth 2.0 authorization server access token",
   170  				zap.String("request_id", r.ID),
   171  				zap.Any("claims", m),
   172  			)
   173  			return nil
   174  		case idTokenExists && accessTokenExists:
   175  			accessToken := map[string]interface{}{
   176  				"access_token": reqParamsAccessToken,
   177  				"id_token":     reqParamsIDToken,
   178  			}
   179  			m, err := b.validateAccessToken(reqParamsState, accessToken)
   180  			if err != nil {
   181  				return errors.ErrIdentityProviderOauthValidateAccessTokenFailed.WithArgs(err)
   182  			}
   183  
   184  			r.Response.Payload = m
   185  			r.Response.Code = http.StatusOK
   186  
   187  			if b.config.IdentityTokenCookieEnabled {
   188  				r.Response.IdentityTokenCookie.Enabled = true
   189  				r.Response.IdentityTokenCookie.Name = b.config.IdentityTokenCookieName
   190  				r.Response.IdentityTokenCookie.Payload = reqParamsIDToken
   191  			}
   192  
   193  			b.logger.Debug(
   194  				"decoded claims from OAuth 2.0 authorization server access token",
   195  				zap.String("request_id", r.ID),
   196  				zap.Any("claims", m),
   197  			)
   198  			return nil
   199  		}
   200  		return errors.ErrIdentityProviderOauthResponseProcessingFailed
   201  	}
   202  	r.Response.Code = http.StatusFound
   203  	state := uuid.New().String()
   204  	nonce := util.GetRandomString(32)
   205  	params := url.Values{}
   206  	// CSRF Protection
   207  	params.Set("state", state)
   208  	if !b.disableNonce {
   209  		// Server Side-Replay Protection
   210  		params.Set("nonce", nonce)
   211  	}
   212  	if !b.disableScope {
   213  		scopes := b.config.Scopes
   214  		if additionalScopesExists {
   215  			scopes = append(scopes, strings.Split(additionalScopes, " ")...)
   216  		}
   217  		params.Set("scope", strings.Join(scopes, " "))
   218  	}
   219  
   220  	if b.config.JsCallbackEnabled {
   221  		params.Set("redirect_uri", reqPath+"/authorization-code-js-callback")
   222  	} else {
   223  		params.Set("redirect_uri", reqPath+"/authorization-code-callback")
   224  	}
   225  
   226  	if !b.disableResponseType {
   227  		params.Set("response_type", strings.Join(b.config.ResponseType, " "))
   228  	}
   229  	if loginHintExists {
   230  		params.Set("login_hint", reqParamsLoginHint)
   231  	}
   232  
   233  	params.Set("client_id", b.config.ClientID)
   234  
   235  	r.Response.RedirectURL = b.authorizationURL + "?" + params.Encode()
   236  
   237  	b.state.add(state, nonce)
   238  	b.logger.Debug(
   239  		"redirecting to OAuth 2.0 endpoint",
   240  		zap.String("request_id", r.ID),
   241  		zap.String("redirect_url", r.Response.RedirectURL),
   242  	)
   243  	return nil
   244  }
   245  
   246  func (b *IdentityProvider) fetchAccessToken(redirectURI, state, code string) (map[string]interface{}, error) {
   247  	params := url.Values{}
   248  	params.Set("client_id", b.config.ClientID)
   249  	params.Set("client_secret", b.config.ClientSecret)
   250  	if !b.disablePassGrantType {
   251  		params.Set("grant_type", "authorization_code")
   252  	}
   253  	params.Set("state", state)
   254  	params.Set("code", code)
   255  	params.Set("redirect_uri", redirectURI)
   256  
   257  	cli := &http.Client{
   258  		Timeout: time.Second * 10,
   259  	}
   260  
   261  	cli, err := b.newBrowser()
   262  	if err != nil {
   263  		return nil, err
   264  	}
   265  
   266  	req, err := http.NewRequest("POST", b.tokenURL, strings.NewReader(params.Encode()))
   267  	if err != nil {
   268  		return nil, err
   269  	}
   270  
   271  	// Adjust !!!
   272  	if b.enableAcceptHeader {
   273  		req.Header.Set("Accept", "application/json")
   274  	}
   275  
   276  	req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
   277  	req.Header.Add("Content-Length", strconv.Itoa(len(params.Encode())))
   278  
   279  	resp, err := cli.Do(req)
   280  	if err != nil {
   281  		return nil, err
   282  	}
   283  
   284  	respBody, err := ioutil.ReadAll(resp.Body)
   285  	resp.Body.Close()
   286  	if err != nil {
   287  		return nil, err
   288  	}
   289  
   290  	b.logger.Debug(
   291  		"OAuth 2.0 access token response received",
   292  		zap.Any("body", respBody),
   293  		zap.String("redirect_uri", redirectURI),
   294  	)
   295  
   296  	data := make(map[string]interface{})
   297  	if err := json.Unmarshal(respBody, &data); err != nil {
   298  		return nil, err
   299  	}
   300  
   301  	b.logger.Debug(
   302  		"OAuth 2.0 access token response decoded",
   303  		zap.Any("body", data),
   304  	)
   305  
   306  	if _, exists := data["error"]; exists {
   307  		if v, exists := data["error_description"]; exists {
   308  			return nil, errors.ErrIdentityProviderOauthGetAccessTokenFailedDetailed.WithArgs(data["error"].(string), v.(string))
   309  		}
   310  		switch data["error"].(type) {
   311  		case string:
   312  			return nil, errors.ErrIdentityProviderOauthGetAccessTokenFailed.WithArgs(data["error"].(string))
   313  		default:
   314  			return nil, errors.ErrIdentityProviderOauthGetAccessTokenFailed.WithArgs(data["error"])
   315  		}
   316  	}
   317  
   318  	for k := range b.requiredTokenFields {
   319  		if _, exists := data[k]; !exists {
   320  			return nil, errors.ErrIdentityProviderAuthorizationServerResponseFieldNotFound.WithArgs(k)
   321  		}
   322  	}
   323  	return data, nil
   324  }
   325  
   326  func (b *IdentityProvider) fetchFacebookAccessToken(redirectURI, state, code string) (map[string]interface{}, error) {
   327  	params := url.Values{}
   328  	params.Set("client_id", b.config.ClientID)
   329  	params.Set("client_secret", b.config.ClientSecret)
   330  	params.Set("code", code)
   331  	params.Set("redirect_uri", redirectURI)
   332  
   333  	cli := &http.Client{
   334  		Timeout: time.Second * 10,
   335  	}
   336  
   337  	cli, err := b.newBrowser()
   338  	if err != nil {
   339  		return nil, err
   340  	}
   341  
   342  	req, err := http.NewRequest("GET", b.tokenURL, nil)
   343  	if err != nil {
   344  		return nil, err
   345  	}
   346  
   347  	req.URL.RawQuery = params.Encode()
   348  
   349  	// Adjust !!!
   350  	if b.enableAcceptHeader {
   351  		req.Header.Set("Accept", "application/json")
   352  	}
   353  
   354  	resp, err := cli.Do(req)
   355  	if err != nil {
   356  		return nil, err
   357  	}
   358  
   359  	respBody, err := ioutil.ReadAll(resp.Body)
   360  	resp.Body.Close()
   361  	if err != nil {
   362  		return nil, err
   363  	}
   364  	b.logger.Debug(
   365  		"OAuth 2.0 access token response received",
   366  		zap.Any("body", respBody),
   367  	)
   368  
   369  	data := make(map[string]interface{})
   370  	if err := json.Unmarshal(respBody, &data); err != nil {
   371  		return nil, err
   372  	}
   373  	if _, exists := data["error"]; exists {
   374  		if v, exists := data["error_description"]; exists {
   375  			return nil, errors.ErrIdentityProviderOauthGetAccessTokenFailedDetailed.WithArgs(data["error"].(string), v.(string))
   376  		}
   377  		switch data["error"].(type) {
   378  		case string:
   379  			return nil, errors.ErrIdentityProviderOauthGetAccessTokenFailed.WithArgs(data["error"].(string))
   380  		default:
   381  			return nil, errors.ErrIdentityProviderOauthGetAccessTokenFailed.WithArgs(data["error"])
   382  		}
   383  	}
   384  
   385  	for k := range b.requiredTokenFields {
   386  		if _, exists := data[k]; !exists {
   387  			return nil, errors.ErrIdentityProviderAuthorizationServerResponseFieldNotFound.WithArgs(k)
   388  		}
   389  	}
   390  	return data, nil
   391  }