github.com/Axway/agent-sdk@v1.1.101/pkg/authz/oauth/authclient_test.go (about)

     1  package oauth
     2  
     3  import (
     4  	"encoding/json"
     5  	"net/http"
     6  	"net/http/httptest"
     7  	"net/url"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/Axway/agent-sdk/pkg/api"
    12  	"github.com/Axway/agent-sdk/pkg/config"
    13  	"github.com/Axway/agent-sdk/pkg/util"
    14  	"github.com/stretchr/testify/assert"
    15  )
    16  
    17  func assertHeaders(t *testing.T, expected map[string]string, actual http.Header) {
    18  	for key, val := range expected {
    19  		actualVal := actual.Get(key)
    20  		assert.Equal(t, val, actualVal)
    21  	}
    22  }
    23  
    24  func assertQueryParams(t *testing.T, expected map[string]string, actual url.Values) {
    25  	for key, val := range expected {
    26  		actualVal := actual.Get(key)
    27  		assert.Equal(t, val, actualVal)
    28  	}
    29  }
    30  
    31  func TestResponseDecode(t *testing.T) {
    32  	fixture := `
    33  {
    34    "access_token": "some_value",
    35    "expires_in": 1800,
    36    "refresh_expires_in": 21600,
    37    "refresh_token": "some_value",
    38    "token_type": "bearer",
    39    "not-before-policy": 1510148785,
    40    "session_state": "f4b0fe58-a6f7-4452-9010-3945a7ecd493"
    41  }`
    42  
    43  	tokens := &tokenResponse{}
    44  	json.Unmarshal([]byte(fixture), tokens)
    45  	if tokens.AccessToken != "some_value" {
    46  		t.Error("unexpected access token value")
    47  	}
    48  	if tokens.ExpiresIn != 1800 {
    49  		t.Error("unexpected expires in token")
    50  	}
    51  }
    52  
    53  func TestEmptyTokenHolder(t *testing.T) {
    54  	ac := &authClient{}
    55  
    56  	if token := ac.getCachedToken(); token != "" {
    57  		t.Error("unexpected token from cache")
    58  	}
    59  }
    60  
    61  func TestExpiredTokenHolder(t *testing.T) {
    62  	ac := &authClient{
    63  		cachedToken: &tokenResponse{
    64  			AccessToken: "some_token",
    65  		},
    66  		cachedTokenExpiry: time.Now().Add(-time.Hour),
    67  	}
    68  
    69  	time.Sleep(time.Millisecond)
    70  
    71  	if token := ac.getCachedToken(); token != "" {
    72  		t.Error("unexpected token from cache")
    73  	}
    74  }
    75  
    76  func TestGetPlatformTokensHttpError(t *testing.T) {
    77  	s := NewMockIDPServer()
    78  	defer s.Close()
    79  
    80  	apiClient := api.NewClient(config.NewTLSConfig(), "")
    81  	s.SetTokenResponse("", 0, http.StatusBadRequest)
    82  	ac, err := NewAuthClient(s.GetTokenURL(), apiClient,
    83  		WithServerName("testServer"),
    84  		WithClientSecretPostAuth("invalid_client", "invalid-secrt", ""))
    85  	assert.Nil(t, err)
    86  	assert.NotNil(t, ac)
    87  
    88  	_, err = ac.GetToken()
    89  	assert.NotNil(t, err)
    90  
    91  	privateKey, _ := util.ReadPrivateKeyFile("testdata/private_key.pem", "")
    92  	publicKey, _ := util.ReadPublicKeyBytes("testdata/publickey")
    93  	s.SetTokenResponse("", 0, http.StatusBadRequest)
    94  	ac, err = NewAuthClient(s.GetTokenURL(), apiClient,
    95  		WithServerName("testServer"),
    96  		WithKeyPairAuth("invalid_client", "", "", privateKey, publicKey, "", ""))
    97  	assert.Nil(t, err)
    98  	assert.NotNil(t, ac)
    99  
   100  	_, err = ac.GetToken()
   101  	assert.NotNil(t, err)
   102  
   103  	s.SetTokenResponse("token", 3*time.Second, http.StatusOK)
   104  	ac, err = NewAuthClient(s.GetTokenURL(), apiClient,
   105  		WithServerName("testServer"),
   106  		WithKeyPairAuth("invalid_client", "", "", privateKey, publicKey, "", ""))
   107  	assert.Nil(t, err)
   108  	assert.NotNil(t, ac)
   109  
   110  	token, err := ac.GetToken()
   111  	assert.Nil(t, err)
   112  	assert.Equal(t, "token", token)
   113  }
   114  
   115  func TestGetPlatformTokensTimeout(t *testing.T) {
   116  	s := httptest.NewServer(http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
   117  		time.Sleep(2 * time.Second)
   118  	}))
   119  	defer s.Close()
   120  
   121  	apiClient := api.NewClientWithTimeout(config.NewTLSConfig(), "", time.Second)
   122  	ac, err := NewAuthClient(s.URL, apiClient,
   123  		WithServerName("testServer"),
   124  		WithClientSecretPostAuth("invalid_client", "invalid-secrt", ""))
   125  
   126  	assert.Nil(t, err)
   127  	assert.NotNil(t, ac)
   128  
   129  	_, err = ac.GetToken()
   130  	assert.NotNil(t, err)
   131  }
   132  
   133  func TestAuthClientTypes(t *testing.T) {
   134  	s := NewMockIDPServer()
   135  	defer s.Close()
   136  	keyReader := NewKeyReader(
   137  		"testdata/private_key.pem",
   138  		"testdata/publickey",
   139  		"",
   140  	)
   141  	privateKey, keyErr := keyReader.GetPrivateKey()
   142  	assert.Nil(t, keyErr)
   143  
   144  	publicKey, keyErr := keyReader.GetPublicKey()
   145  	assert.Nil(t, keyErr)
   146  
   147  	cases := []struct {
   148  		name                             string
   149  		tokenReqWithAuthorization        bool
   150  		typedAuthOpt                     AuthClientOption
   151  		headers                          map[string]string
   152  		queryParams                      map[string]string
   153  		expectedTokenReqClientID         string
   154  		expectedTokenReqClientSecret     string
   155  		expectedTokenReqClientAssertType string
   156  		expectedTokenReqScope            string
   157  	}{
   158  		{
   159  			name:                      "test",
   160  			headers:                   map[string]string{"hdr": "val"},
   161  			queryParams:               map[string]string{"param": "param-val"},
   162  			typedAuthOpt:              WithClientSecretBasicAuth("test-id", "test-secret", "test-scope"),
   163  			tokenReqWithAuthorization: true,
   164  			expectedTokenReqScope:     "test-scope",
   165  		},
   166  		{
   167  			name:                         "test",
   168  			typedAuthOpt:                 WithClientSecretPostAuth("test-id", "test-secret", "test-scope"),
   169  			expectedTokenReqClientID:     "test-id",
   170  			expectedTokenReqClientSecret: "test-secret",
   171  			expectedTokenReqScope:        "test-scope",
   172  		},
   173  		{
   174  			name:                             "test",
   175  			typedAuthOpt:                     WithClientSecretJwtAuth("test-id", "test-secret", "test-scope", "", "aud", ""),
   176  			expectedTokenReqClientID:         "test-id",
   177  			expectedTokenReqScope:            "test-scope",
   178  			expectedTokenReqClientAssertType: "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
   179  		},
   180  		{
   181  			name:                             "test",
   182  			typedAuthOpt:                     WithKeyPairAuth("test-id", "", "aud", privateKey, publicKey, "test-scope", ""),
   183  			expectedTokenReqScope:            "test-scope",
   184  			expectedTokenReqClientAssertType: "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
   185  		},
   186  		{
   187  			name:                     "test",
   188  			typedAuthOpt:             WithTLSClientAuth("test-id", "test-scope"),
   189  			expectedTokenReqClientID: "test-id",
   190  			expectedTokenReqScope:    "test-scope",
   191  		},
   192  	}
   193  	for _, tc := range cases {
   194  		s.SetTokenResponse("token", 3*time.Second, http.StatusOK)
   195  		apiClient := api.NewClientWithTimeout(config.NewTLSConfig(), "", time.Second)
   196  		opts := []AuthClientOption{
   197  			WithServerName("testServer"),
   198  			tc.typedAuthOpt,
   199  			WithRequestHeaders(tc.headers),
   200  			WithQueryParams(tc.queryParams),
   201  		}
   202  		ac, err := NewAuthClient(s.GetTokenURL(), apiClient, opts...)
   203  
   204  		assert.Nil(t, err)
   205  		assert.NotNil(t, ac)
   206  
   207  		token, err := ac.GetToken()
   208  		assert.Nil(t, err)
   209  		assert.Equal(t, "token", token)
   210  		if tc.tokenReqWithAuthorization {
   211  			headers := s.GetTokenRequestHeaders()
   212  			authHeaderVal := headers.Get("Authorization")
   213  			assert.NotEmpty(t, authHeaderVal)
   214  		}
   215  		tokenReqValues := s.GetTokenRequestValues()
   216  
   217  		grantType := tokenReqValues.Get("grant_type")
   218  		assert.Equal(t, "client_credentials", grantType)
   219  
   220  		clientID := tokenReqValues.Get("client_id")
   221  		assert.Equal(t, tc.expectedTokenReqClientID, clientID)
   222  
   223  		clientSecret := tokenReqValues.Get("client_secret")
   224  		assert.Equal(t, tc.expectedTokenReqClientSecret, clientSecret)
   225  
   226  		clientAssertType := tokenReqValues.Get("client_assertion_type")
   227  		assert.Equal(t, tc.expectedTokenReqClientAssertType, clientAssertType)
   228  
   229  		tokenScope := tokenReqValues.Get("scope")
   230  		assert.Equal(t, tc.expectedTokenReqScope, tokenScope)
   231  
   232  		assertHeaders(t, tc.headers, s.GetTokenRequestHeaders())
   233  		assertQueryParams(t, tc.queryParams, s.GetTokenQueryParams())
   234  	}
   235  }