github.com/versent/saml2aws@v2.17.0+incompatible/pkg/provider/onelogin/onelogin.go (about)

     1  package onelogin
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/json"
     6  	"fmt"
     7  	"io/ioutil"
     8  	"net/http"
     9  	"net/url"
    10  	"strings"
    11  	"time"
    12  
    13  	"github.com/pkg/errors"
    14  	"github.com/sirupsen/logrus"
    15  	"github.com/tidwall/gjson"
    16  	"github.com/versent/saml2aws/pkg/cfg"
    17  	"github.com/versent/saml2aws/pkg/creds"
    18  	"github.com/versent/saml2aws/pkg/prompter"
    19  	"github.com/versent/saml2aws/pkg/provider"
    20  )
    21  
    22  // MFA identifier constants.
    23  const (
    24  	IdentifierOneLoginProtectMfa = "OneLogin Protect"
    25  	IdentifierSmsMfa             = "OneLogin SMS"
    26  	IdentifierTotpMfa            = "Google Authenticator"
    27  
    28  	MessageMFARequired = "MFA is required for this user"
    29  	MessageSuccess     = "Success"
    30  	TypePending        = "pending"
    31  	TypeSuccess        = "success"
    32  )
    33  
    34  // ProviderName constant holds the name of the OneLogin IDP.
    35  const ProviderName = "OneLogin"
    36  
    37  var logger = logrus.WithField("provider", ProviderName)
    38  
    39  var (
    40  	supportedMfaOptions = map[string]string{
    41  		IdentifierOneLoginProtectMfa: "OLP",
    42  		IdentifierSmsMfa:             "SMS",
    43  		IdentifierTotpMfa:            "TOTP",
    44  	}
    45  )
    46  
// Client is a wrapper representing a OneLogin SAML client.
type Client struct {
	// AppID represents the OneLogin connector id. It is sent as app_id when
	// requesting the SAML assertion and when verifying an MFA device.
	AppID string
	// Client is the HTTP client for accessing the IDP provider's APIs.
	Client *provider.HTTPClient
	// MFA is an optional pre-selected MFA option ("OLP", "SMS" or "TOTP"). When it
	// matches one of the user's enrolled devices, that device is used without
	// prompting the user to choose.
	MFA string
	// Subdomain is the organisation subdomain in OneLogin.
	Subdomain string
}
    58  
// AuthRequest is the JSON body POSTed to the OneLogin saml_assertion endpoint to
// authenticate a user and request a SAML assertion for an app.
// Field order determines the key order of the encoded JSON.
type AuthRequest struct {
	AppID     string `json:"app_id"`
	Password  string `json:"password"`
	Subdomain string `json:"subdomain"`
	Username  string `json:"username_or_email"`
	// IPAddress is optional and omitted from the JSON when empty; Authenticate does not set it.
	IPAddress string `json:"ip_address,omitempty"`
}
    67  
// VerifyRequest is the JSON body POSTed to an MFA callback URL to trigger or
// complete verification of one MFA device.
// Field order determines the key order of the encoded JSON.
type VerifyRequest struct {
	AppID    string `json:"app_id"`
	DeviceID string `json:"device_id"`
	// DoNotNotify is set while polling OneLogin Protect so that repeated requests
	// do not send further push notifications.
	DoNotNotify bool `json:"do_not_notify"`
	// OTPToken carries the user-entered one-time code (SMS/TOTP); omitted when empty.
	OTPToken   string `json:"otp_token,omitempty"`
	StateToken string `json:"state_token"`
}
    76  
    77  // New creates a new OneLogin client.
    78  func New(idpAccount *cfg.IDPAccount) (*Client, error) {
    79  	tr := provider.NewDefaultTransport(idpAccount.SkipVerify)
    80  	client, err := provider.NewHTTPClient(tr)
    81  	if err != nil {
    82  		return nil, errors.Wrap(err, "error building http client")
    83  	}
    84  	return &Client{AppID: idpAccount.AppID, Client: client, MFA: idpAccount.MFA, Subdomain: idpAccount.Subdomain}, nil
    85  }
    86  
    87  // Authenticate logs into OneLogin and returns a SAML response.
    88  func (c *Client) Authenticate(loginDetails *creds.LoginDetails) (string, error) {
    89  	providerURL, err := url.Parse(loginDetails.URL)
    90  	if err != nil {
    91  		return "", errors.Wrap(err, "error building providerURL")
    92  	}
    93  	host := providerURL.Host
    94  
    95  	logger.Debug("Generating OneLogin access token")
    96  	// request oAuth token required for working with OneLogin APIs
    97  	oauthToken, err := generateToken(c, loginDetails, host)
    98  	if err != nil {
    99  		return "", errors.Wrap(err, "failed to generate oauth token")
   100  	}
   101  
   102  	logger.Debug("Retrieved OneLogin OAuth token:", oauthToken)
   103  
   104  	authReq := AuthRequest{Username: loginDetails.Username, Password: loginDetails.Password, AppID: c.AppID, Subdomain: c.Subdomain}
   105  	var authBody bytes.Buffer
   106  	err = json.NewEncoder(&authBody).Encode(authReq)
   107  	if err != nil {
   108  		return "", errors.Wrap(err, "error encoding authreq")
   109  	}
   110  
   111  	authSubmitURL := fmt.Sprintf("https://%s/api/1/saml_assertion", host)
   112  
   113  	req, err := http.NewRequest("POST", authSubmitURL, &authBody)
   114  	if err != nil {
   115  		return "", errors.Wrap(err, "error building authentication request")
   116  	}
   117  
   118  	addContentHeaders(req)
   119  	addAuthHeader(req, oauthToken)
   120  
   121  	logger.Debug("Requesting SAML Assertion")
   122  
   123  	// request the SAML assertion. For more details check https://developers.onelogin.com/api-docs/1/saml-assertions/generate-saml-assertion
   124  	res, err := c.Client.Do(req)
   125  	if err != nil {
   126  		return "", errors.Wrap(err, "error retrieving auth response")
   127  	}
   128  	defer res.Body.Close()
   129  
   130  	body, err := ioutil.ReadAll(res.Body)
   131  	if err != nil {
   132  		return "", errors.Wrap(err, "error retrieving body from response")
   133  	}
   134  
   135  	resp := string(body)
   136  
   137  	logger.Debug("SAML Assertion response code:", res.StatusCode)
   138  	logger.Debug("SAML Assertion response body:", resp)
   139  
   140  	authError := gjson.Get(resp, "status.error").Bool()
   141  	authMessage := gjson.Get(resp, "status.message").String()
   142  	authType := gjson.Get(resp, "status.type").String()
   143  	if authError || authType != TypeSuccess {
   144  		return "", errors.New(authMessage)
   145  	}
   146  
   147  	authData := gjson.Get(resp, "data")
   148  	var samlAssertion string
   149  	switch authMessage {
   150  	// MFA not required
   151  	case MessageSuccess:
   152  		if authData.IsArray() {
   153  			return "", errors.New("invalid SAML assertion returned")
   154  		}
   155  		samlAssertion = authData.String()
   156  	case MessageMFARequired:
   157  		if !authData.IsArray() {
   158  			return "", errors.New("invalid MFA data returned")
   159  		}
   160  		logger.Debug("Verifying MFA")
   161  		samlAssertion, err = verifyMFA(c, oauthToken, c.AppID, resp)
   162  		if err != nil {
   163  			return "", errors.Wrap(err, "error verifying MFA")
   164  		}
   165  	default:
   166  		return "", errors.New("unexpected SAML assertion response")
   167  	}
   168  
   169  	return samlAssertion, nil
   170  }
   171  
   172  // generateToken is used to generate access token for all OneLogin APIs.
   173  // For more infor read https://developers.onelogin.com/api-docs/1/oauth20-tokens/generate-tokens-2
   174  func generateToken(oc *Client, loginDetails *creds.LoginDetails, host string) (string, error) {
   175  	oauthTokenURL := fmt.Sprintf("https://%s/auth/oauth2/v2/token", host)
   176  	req, err := http.NewRequest("POST", oauthTokenURL, strings.NewReader(`{"grant_type":"client_credentials"}`))
   177  	if err != nil {
   178  		return "", errors.Wrap(err, "error building oauth token request")
   179  	}
   180  
   181  	addContentHeaders(req)
   182  	req.SetBasicAuth(loginDetails.ClientID, loginDetails.ClientSecret)
   183  	res, err := oc.Client.Do(req)
   184  	if err != nil {
   185  		return "", errors.Wrap(err, "error retrieving oauth token response")
   186  	}
   187  
   188  	body, err := ioutil.ReadAll(res.Body)
   189  	if err != nil {
   190  		return "", errors.Wrap(err, "error reading oauth token response")
   191  	}
   192  	defer res.Body.Close()
   193  
   194  	return gjson.Get(string(body), "access_token").String(), nil
   195  }
   196  
   197  func addAuthHeader(r *http.Request, oauthToken string) {
   198  	r.Header.Add("Authorization", "bearer: "+oauthToken)
   199  }
   200  
   201  func addContentHeaders(r *http.Request) {
   202  	r.Header.Add("Content-Type", "application/json")
   203  	r.Header.Add("Accept", "application/json")
   204  }
   205  
   206  // verifyMFA is used to either prompt to user for one time password or request approval using push notification.
   207  // For more details check https://developers.onelogin.com/api-docs/1/saml-assertions/verify-factor
   208  func verifyMFA(oc *Client, oauthToken, appID, resp string) (string, error) {
   209  	stateToken := gjson.Get(resp, "data.0.state_token").String()
   210  	// choose an mfa option if there are multiple enabled
   211  	var option int
   212  	var mfaOptions []string
   213  	var preselected bool
   214  	for n, id := range gjson.Get(resp, "data.0.devices.#.device_type").Array() {
   215  		identifier := id.String()
   216  		if val, ok := supportedMfaOptions[identifier]; ok {
   217  			mfaOptions = append(mfaOptions, val)
   218  			// If there is pre-selected MFA option (thorugh the --mfa flag), then set MFA option index and break early.
   219  			if val == oc.MFA {
   220  				option = n
   221  				preselected = true
   222  				break
   223  			}
   224  		} else {
   225  			mfaOptions = append(mfaOptions, "UNSUPPORTED: "+identifier)
   226  		}
   227  	}
   228  	if !preselected && len(mfaOptions) > 1 {
   229  		option = prompter.Choose("Select which MFA option to use", mfaOptions)
   230  	}
   231  
   232  	factorID := gjson.Get(resp, fmt.Sprintf("data.0.devices.%d.device_id", option)).String()
   233  	callbackURL := gjson.Get(resp, "data.0.callback_url").String()
   234  	mfaIdentifer := gjson.Get(resp, fmt.Sprintf("data.0.devices.%d.device_type", option)).String()
   235  	mfaDeviceID := gjson.Get(resp, fmt.Sprintf("data.0.devices.%d.device_id", option)).String()
   236  
   237  	logger.WithField("factorID", factorID).WithField("callbackURL", callbackURL).WithField("mfaIdentifer", mfaIdentifer).Debug("MFA")
   238  
   239  	if _, ok := supportedMfaOptions[mfaIdentifer]; !ok {
   240  		return "", errors.New("unsupported mfa provider")
   241  	}
   242  
   243  	// TOTP MFA doesn't need additional request (e.g. to send SMS or a push notification etc) since the user can generate the code using their MFA app of choice.
   244  	if mfaIdentifer != IdentifierTotpMfa {
   245  		var verifyBody bytes.Buffer
   246  		err := json.NewEncoder(&verifyBody).Encode(VerifyRequest{AppID: appID, DeviceID: mfaDeviceID, StateToken: stateToken})
   247  		if err != nil {
   248  			return "", errors.Wrap(err, "error encoding verifyReq")
   249  		}
   250  
   251  		req, err := http.NewRequest("POST", callbackURL, &verifyBody)
   252  		if err != nil {
   253  			return "", errors.Wrap(err, "error building verify request")
   254  		}
   255  
   256  		addContentHeaders(req)
   257  		addAuthHeader(req, oauthToken)
   258  		res, err := oc.Client.Do(req)
   259  		if err != nil {
   260  			return "", errors.Wrap(err, "error retrieving verify response")
   261  		}
   262  
   263  		body, err := ioutil.ReadAll(res.Body)
   264  		if err != nil {
   265  			return "", errors.Wrap(err, "error retrieving body from response")
   266  		}
   267  		resp = string(body)
   268  		if gjson.Get(resp, "status.error").Bool() {
   269  			msg := gjson.Get(resp, "status.message").String()
   270  			return "", errors.New(msg)
   271  		}
   272  	}
   273  
   274  	switch mfaIdentifer {
   275  	case IdentifierSmsMfa, IdentifierTotpMfa:
   276  		verifyCode := prompter.StringRequired("Enter verification code")
   277  		var verifyBody bytes.Buffer
   278  		json.NewEncoder(&verifyBody).Encode(VerifyRequest{AppID: appID, DeviceID: mfaDeviceID, StateToken: stateToken, OTPToken: verifyCode})
   279  		req, err := http.NewRequest("POST", callbackURL, &verifyBody)
   280  		if err != nil {
   281  			return "", errors.Wrap(err, "error building token post request")
   282  		}
   283  
   284  		addContentHeaders(req)
   285  		addAuthHeader(req, oauthToken)
   286  		res, err := oc.Client.Do(req)
   287  		if err != nil {
   288  			return "", errors.Wrap(err, "error retrieving token post response")
   289  		}
   290  
   291  		body, err := ioutil.ReadAll(res.Body)
   292  		if err != nil {
   293  			return "", errors.Wrap(err, "error retrieving body from response")
   294  		}
   295  
   296  		resp = string(body)
   297  
   298  		message := gjson.Get(resp, "status.message").String()
   299  		if gjson.Get(resp, "status.error").Bool() {
   300  			return "", errors.New(message)
   301  		}
   302  
   303  		return gjson.Get(resp, "data").String(), nil
   304  
   305  	case IdentifierOneLoginProtectMfa:
   306  		// set the body payload to disable further push notifications (i.e. set do_not_notify to true)
   307  		// https://developers.onelogin.com/api-docs/1/saml-assertions/verify-factor
   308  		var verifyBody bytes.Buffer
   309  		err := json.NewEncoder(&verifyBody).Encode(VerifyRequest{AppID: appID, DeviceID: mfaDeviceID, DoNotNotify: true, StateToken: stateToken})
   310  		if err != nil {
   311  			return "", errors.New("error encoding verify MFA request body")
   312  		}
   313  		req, err := http.NewRequest("POST", callbackURL, &verifyBody)
   314  		if err != nil {
   315  			return "", errors.Wrap(err, "error building token post request")
   316  		}
   317  
   318  		addContentHeaders(req)
   319  		addAuthHeader(req, oauthToken)
   320  
   321  		fmt.Printf("\nWaiting for approval, please check your OneLogin Protect app ...")
   322  		started := time.Now()
   323  		// loop until success, error, or timeout
   324  		for {
   325  			if time.Since(started) > time.Minute {
   326  				fmt.Println(" Timeout")
   327  				return "", errors.New("User did not accept MFA in time")
   328  			}
   329  
   330  			logger.Debug("Verifying with OneLogin Protect")
   331  			res, err := oc.Client.Do(req)
   332  			if err != nil {
   333  				return "", errors.Wrap(err, "error retrieving verify response")
   334  			}
   335  
   336  			body, err := ioutil.ReadAll(res.Body)
   337  			if err != nil {
   338  				return "", errors.Wrap(err, "error retrieving body from response")
   339  			}
   340  
   341  			message := gjson.Get(string(body), "status.message").String()
   342  
   343  			// on 'error' status
   344  			if gjson.Get(string(body), "status.error").Bool() {
   345  				return "", errors.New(message)
   346  			}
   347  
   348  			switch gjson.Get(string(body), "status.type").String() {
   349  			case TypePending:
   350  				time.Sleep(time.Second)
   351  				fmt.Print(".")
   352  
   353  			case TypeSuccess:
   354  				fmt.Println(" Approved")
   355  				return gjson.Get(string(body), "data").String(), nil
   356  
   357  			default:
   358  				fmt.Println(" Error:")
   359  				return "", errors.New("unsupported response from OneLogin, please raise ticket with saml2aws")
   360  			}
   361  		}
   362  	}
   363  
   364  	// catch all
   365  	return "", errors.New("no mfa options provided")
   366  }