github.com/Venafi/vcert/v5@v5.10.2/pkg/venafi/firefly/identityProviderServer_test.go (about)

     1  package firefly
     2  
     3  import (
     4  	"encoding/json"
     5  	"log"
     6  	"net/http"
     7  	"net/http/httptest"
     8  	"strings"
     9  )
    10  
    11  const (
    12  	TestingClientID              = "1234567890"
    13  	TestingClientSecret          = "my_secret"
    14  	TestingUserName              = "my_name"
    15  	TestingUserPassword          = "my_password"
    16  	TestingDeviceCode            = "my_device_code"
    17  	TestingClientIDAuthPending   = "123"
    18  	TestingDeviceAuthPending     = "device_code_pending"
    19  	TestingClientIDSlowDown      = "456"
    20  	TestingDeviceSlowDown        = "device_code_slow_down"
    21  	TestingClientIDAccessDenied  = "789"
    22  	TestingDeviceAccessDenied    = "device_code_access_denied"
    23  	TestingClientIDExpiredToken  = "012"
    24  	TestingDeviceExpiredToken    = "device_code_expired_token"
    25  	TestingDeviceVerificationUri = "my_device_uri"
    26  	TestingScope                 = "my_scope"
    27  	TestingAudience              = "my_audience"
    28  	TestingAccessToken           = "my_access_token"
    29  )
    30  
    31  var (
    32  	authPendingCount = 0
    33  	slowDownCount    = 0
    34  )
    35  
    36  func newIdentityProviderMockServer() *IdentityProviderMockServer {
    37  	tokenPath := "/token"
    38  	devicePath := "/device"
    39  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    40  		switch strings.TrimSpace(r.URL.Path) {
    41  		case tokenPath:
    42  			processAccessTokenRequest(w, r)
    43  		case devicePath:
    44  			processDeviceCodeRequest(w, r)
    45  		default:
    46  			http.NotFoundHandler().ServeHTTP(w, r)
    47  		}
    48  	}))
    49  	//creating and returning the idp server
    50  	return &IdentityProviderMockServer{
    51  		server:     server,
    52  		idpURL:     server.URL,
    53  		tokenPath:  tokenPath,
    54  		devicePath: devicePath,
    55  	}
    56  }
    57  
    58  type IdentityProviderMockServer struct {
    59  	server     *httptest.Server
    60  	idpURL     string
    61  	tokenPath  string
    62  	devicePath string
    63  }
    64  
    65  type AccessTokenRequest struct {
    66  	grantType    string `json:"grant_type"`
    67  	clientId     string `json:"client_id"`
    68  	clientSecret string `json:"client_secret,omitempty"`
    69  	username     string `json:"username,omitempty"`
    70  	password     string `json:"password,omitempty"`
    71  	deviceCode   string `json:"device_code"`
    72  	scope        string `json:"scope"`
    73  	audience     string `json:"audience,omitempty"`
    74  }
    75  
    76  type AccessTokenResponse struct {
    77  	TokenType    string `json:"token_type"`
    78  	AccessToken  string `json:"access_token"`
    79  	RefreshToken string `json:"refresh_token"`
    80  	ExpiresIn    int32  `json:"expires_in"`
    81  	Scope        string `json:"scope"`
    82  }
    83  
    84  func processAccessTokenRequest(w http.ResponseWriter, r *http.Request) {
    85  	accessTokenRequest, err := parseAccessTokenRequest(r)
    86  	if err != nil {
    87  		writeError(w, http.StatusBadRequest, err.Error(), "")
    88  		return
    89  	}
    90  
    91  	if !validateAccessTokenRequest(w, accessTokenRequest) {
    92  		return
    93  	}
    94  
    95  	//Headers must be set before the status and the body are written to the response.
    96  	w.Header().Set("Content-Type", "application/json")
    97  	w.WriteHeader(http.StatusOK)
    98  
    99  	accessTokenResponse := AccessTokenResponse{
   100  		TokenType:    "Bearer",
   101  		AccessToken:  TestingAccessToken,
   102  		RefreshToken: "",
   103  		ExpiresIn:    120, //seconds
   104  		Scope:        TestingScope,
   105  	}
   106  
   107  	jsonResp, err := json.Marshal(accessTokenResponse)
   108  	if err != nil {
   109  		log.Fatalf("Error happened in JSON marshal. Err: %s", err)
   110  	}
   111  	w.Write(jsonResp)
   112  }
   113  
   114  // validateAccessTokenRequest returns true if the request is valid, else return false
   115  func validateAccessTokenRequest(w http.ResponseWriter, accessTokenRequest AccessTokenRequest) bool {
   116  	if accessTokenRequest.clientId == "" {
   117  		writeError(w, http.StatusBadRequest, "Status Bad Request", "The client_id is missing")
   118  		return false
   119  	}
   120  
   121  	switch accessTokenRequest.grantType {
   122  	case "client_credentials":
   123  		if accessTokenRequest.clientId != TestingClientID {
   124  			writeError(w, http.StatusUnauthorized, "Status Unauthorized Request", "The client_id is not valid")
   125  			return false
   126  		}
   127  
   128  		if accessTokenRequest.clientSecret == "" {
   129  			writeError(w, http.StatusBadRequest, "Status Bad Request", "The client_secret is missing")
   130  			return false
   131  		}
   132  
   133  		if accessTokenRequest.clientSecret != TestingClientSecret {
   134  			writeError(w, http.StatusUnauthorized, "Status Unauthorized Request", "The client_secret is not valid")
   135  			return false
   136  		}
   137  
   138  		if accessTokenRequest.scope == "" {
   139  			writeError(w, http.StatusBadRequest, "Status Bad Request", "The scope is missing")
   140  			return false
   141  		}
   142  
   143  		if accessTokenRequest.scope != TestingScope {
   144  			writeError(w, http.StatusUnauthorized, "Status Unauthorized Request", "The scope is not valid")
   145  			return false
   146  		}
   147  	case "password":
   148  		if accessTokenRequest.clientId != TestingClientID {
   149  			writeError(w, http.StatusUnauthorized, "Status Unauthorized Request", "The client_id is not valid")
   150  			return false
   151  		}
   152  
   153  		if accessTokenRequest.username == "" {
   154  			writeError(w, http.StatusBadRequest, "Status Bad Request", "The username is missing")
   155  			return false
   156  		}
   157  
   158  		if accessTokenRequest.username != TestingUserName {
   159  			writeError(w, http.StatusUnauthorized, "Status Unauthorized Request", "The username is not valid")
   160  			return false
   161  		}
   162  
   163  		if accessTokenRequest.password == "" {
   164  			writeError(w, http.StatusBadRequest, "Status Bad Request", "The password is missing")
   165  			return false
   166  		}
   167  
   168  		if accessTokenRequest.password != TestingUserPassword {
   169  			writeError(w, http.StatusUnauthorized, "Status Unauthorized Request", "The password is not valid")
   170  			return false
   171  		}
   172  
   173  		if accessTokenRequest.scope == "" {
   174  			writeError(w, http.StatusBadRequest, "Status Bad Request", "The scope is missing")
   175  			return false
   176  		}
   177  
   178  		if accessTokenRequest.scope != TestingScope {
   179  			writeError(w, http.StatusUnauthorized, "Status Unauthorized Request", "The scope is not valid")
   180  			return false
   181  		}
   182  	case "urn:ietf:params:oauth:grant-type:device_code":
   183  		if accessTokenRequest.deviceCode == "" {
   184  			writeError(w, http.StatusBadRequest, "Status Bad Request", "The device_code is missing")
   185  			return false
   186  		}
   187  
   188  		if accessTokenRequest.clientId == TestingClientID {
   189  			if accessTokenRequest.deviceCode != TestingDeviceCode {
   190  				writeError(w, http.StatusUnauthorized, "Status Unauthorized Request", "The device code is not valid")
   191  				return false
   192  			}
   193  		}
   194  
   195  		if accessTokenRequest.clientId == TestingClientIDAuthPending {
   196  			authPendingCount++
   197  			if accessTokenRequest.deviceCode == TestingDeviceAuthPending {
   198  				if authPendingCount < 3 {
   199  					writeError(w, http.StatusTooEarly, "authorization_pending", "")
   200  				} else {
   201  					//reset the authPendingCount
   202  					authPendingCount = 0
   203  					return true
   204  				}
   205  			} else {
   206  				writeError(w, http.StatusUnauthorized, "Status Unauthorized Request", "The device code is not valid")
   207  			}
   208  
   209  			return false
   210  		}
   211  
   212  		if accessTokenRequest.clientId == TestingClientIDSlowDown {
   213  			slowDownCount++
   214  			if accessTokenRequest.deviceCode == TestingDeviceSlowDown {
   215  				if slowDownCount < 2 {
   216  					writeError(w, http.StatusTooEarly, "slow_down", "")
   217  				} else {
   218  					//reset the slowDownCount
   219  					slowDownCount = 0
   220  					return true
   221  				}
   222  			} else {
   223  				writeError(w, http.StatusUnauthorized, "Status Unauthorized Request", "The device code is not valid")
   224  			}
   225  			return false
   226  		}
   227  
   228  		if accessTokenRequest.clientId == TestingClientIDAccessDenied {
   229  			if accessTokenRequest.deviceCode == TestingDeviceAccessDenied {
   230  				writeError(w, http.StatusUnauthorized, "access_denied", "")
   231  			} else {
   232  				writeError(w, http.StatusUnauthorized, "Status Unauthorized Request", "The device code is not valid")
   233  			}
   234  			return false
   235  		}
   236  
   237  		if accessTokenRequest.clientId == TestingClientIDExpiredToken {
   238  			if accessTokenRequest.deviceCode == TestingDeviceExpiredToken {
   239  				writeError(w, http.StatusUnauthorized, "expired_token", "")
   240  			} else {
   241  				writeError(w, http.StatusUnauthorized, "Status Unauthorized Request", "The device code is not valid")
   242  			}
   243  			return false
   244  		}
   245  	}
   246  
   247  	if accessTokenRequest.audience != "" && accessTokenRequest.audience != TestingAudience {
   248  		writeError(w, http.StatusUnauthorized, "Status Unauthorized Request", "The audience is not valid")
   249  		return false
   250  	}
   251  
   252  	return true
   253  }
   254  
   255  func parseAccessTokenRequest(r *http.Request) (accessTokenRequest AccessTokenRequest, err error) {
   256  	err = r.ParseForm()
   257  	if err != nil {
   258  		return
   259  	}
   260  
   261  	accessTokenRequest = AccessTokenRequest{}
   262  
   263  	for key, value := range r.Form {
   264  		switch key {
   265  		case "grant_type":
   266  			accessTokenRequest.grantType = value[0]
   267  		case "client_id":
   268  			accessTokenRequest.clientId = value[0]
   269  		case "client_secret":
   270  			accessTokenRequest.clientSecret = value[0]
   271  		case "username":
   272  			accessTokenRequest.username = value[0]
   273  		case "password":
   274  			accessTokenRequest.password = value[0]
   275  		case "device_code":
   276  			accessTokenRequest.deviceCode = value[0]
   277  		case "scope":
   278  			accessTokenRequest.scope = value[0]
   279  		case "audience":
   280  			accessTokenRequest.audience = value[0]
   281  		}
   282  	}
   283  
   284  	//if the client_id was not as a query parameter
   285  	if accessTokenRequest.clientId == "" {
   286  		if username, password, ok := r.BasicAuth(); ok {
   287  			accessTokenRequest.clientId = username
   288  			accessTokenRequest.clientSecret = password
   289  		}
   290  	}
   291  
   292  	return
   293  }
   294  
   295  func processDeviceCodeRequest(w http.ResponseWriter, r *http.Request) {
   296  	var clientId, scope, audience, deviceCode string
   297  
   298  	err := r.ParseForm()
   299  	if err != nil {
   300  		writeError(w, http.StatusBadRequest, err.Error(), "")
   301  		return
   302  	}
   303  
   304  	// getting the clientID, the scope and the audience
   305  	for key, value := range r.Form {
   306  		switch key {
   307  		case "client_id":
   308  			clientId = value[0]
   309  		case "scope":
   310  			scope = value[0]
   311  		case "audience":
   312  			audience = value[0]
   313  		}
   314  	}
   315  
   316  	//if the client_id was not as a query parameter
   317  	if clientId == "" {
   318  		if username, _, ok := r.BasicAuth(); ok {
   319  			clientId = username
   320  		}
   321  	}
   322  
   323  	//validating the parameters gotten
   324  	if clientId == "" {
   325  		writeError(w, http.StatusBadRequest, "Status Bad Request", "The client_id is missing")
   326  		return
   327  	}
   328  
   329  	if scope != "" && scope != TestingScope {
   330  		writeError(w, http.StatusUnauthorized, "Status Unauthorized Request", "The scope is not valid")
   331  		return
   332  	}
   333  
   334  	if audience != "" && audience != TestingAudience {
   335  		writeError(w, http.StatusUnauthorized, "Status Unauthorized Request", "The audience is not valid")
   336  		return
   337  	}
   338  
   339  	//Determining the deviceCode to send
   340  	switch clientId {
   341  	case TestingClientID:
   342  		deviceCode = TestingDeviceCode
   343  	case TestingClientIDAuthPending:
   344  		deviceCode = TestingDeviceAuthPending
   345  	case TestingClientIDSlowDown:
   346  		deviceCode = TestingDeviceSlowDown
   347  	case TestingClientIDAccessDenied:
   348  		deviceCode = TestingDeviceAccessDenied
   349  	case TestingClientIDExpiredToken:
   350  		deviceCode = TestingDeviceExpiredToken
   351  	}
   352  
   353  	//Headers must be set before the status and the body are written to the response.
   354  	w.Header().Set("Content-Type", "application/json")
   355  	w.WriteHeader(http.StatusOK)
   356  
   357  	deviceCred := DeviceCred{
   358  		DeviceCode:      deviceCode,
   359  		UserCode:        "1234",
   360  		VerificationURL: "",
   361  		VerificationURI: TestingDeviceVerificationUri,
   362  		Interval:        3,
   363  		ExpiresIn:       15,
   364  	}
   365  
   366  	jsonResp, err := json.Marshal(deviceCred)
   367  	if err != nil {
   368  		log.Fatalf("Error happened in JSON marshal. Err: %s", err)
   369  	}
   370  	w.Write(jsonResp)
   371  
   372  	return
   373  }