github.com/kyma-incubator/compass/components/director@v0.0.0-20230623144113-d764f56ff805/internal/securehttp/caller_test.go (about)

     1  package securehttp_test
     2  
     3  import (
     4  	"encoding/base64"
     5  	"encoding/json"
     6  	"io"
     7  	"net/http"
     8  	"net/http/httptest"
     9  	"strings"
    10  	"testing"
    11  	"time"
    12  
    13  	"github.com/kyma-incubator/compass/components/director/pkg/auth"
    14  
    15  	"golang.org/x/oauth2"
    16  
    17  	"github.com/stretchr/testify/require"
    18  
    19  	"github.com/kyma-incubator/compass/components/director/internal/securehttp"
    20  )
    21  
    22  const (
    23  	testClientID        = "client-id"
    24  	testClientIDKey     = "client_id"
    25  	testGrantType       = "client_credentials"
    26  	testGrantTypeKey    = "grant_type"
    27  	testClientSecret    = "client-secret"
    28  	testUser            = "user"
    29  	testPassword        = "pass"
    30  	testToken           = "bGFzbWFnaS1qYXNtYWdpLWtyaXphLWU="
    31  	basicPrefix         = "Basic "
    32  	oauthPrefix         = "Bearer "
    33  	authorizationHeader = "Authorization"
    34  	testScopes          = "scopes"
    35  	testScopesKey       = "scope"
    36  )
    37  
    38  func TestCaller_Call(t *testing.T) {
    39  	oauthServerExpectingCredentialsFromHeader := httptest.NewServer(getTestOauthServer(t, requireClientCredentialsFromHeader))
    40  	oauthCredentials := &auth.OAuthCredentials{
    41  		ClientID:     testClientID,
    42  		ClientSecret: testClientSecret,
    43  		TokenURL:     oauthServerExpectingCredentialsFromHeader.URL,
    44  		Scopes:       testScopes,
    45  	}
    46  
    47  	oauthServerExpectingCredentialsFromBody := httptest.NewServer(getTestOauthServer(t, requireClientCredentialsFromBody))
    48  	oauthMtlsCredentials := &auth.OAuthMtlsCredentials{
    49  		ClientID: testClientID,
    50  		TokenURL: oauthServerExpectingCredentialsFromBody.URL,
    51  		Scopes:   testScopes,
    52  	}
    53  
    54  	basicCredentials := &auth.BasicCredentials{
    55  		Username: testUser,
    56  		Password: testPassword,
    57  	}
    58  
    59  	testCases := []struct {
    60  		Name        string
    61  		Server      *httptest.Server
    62  		Config      securehttp.CallerConfig
    63  		ExpectedErr error
    64  	}{
    65  		{
    66  			Name: "Success for oauth credentials with secret",
    67  			Config: securehttp.CallerConfig{
    68  				Credentials:       oauthCredentials,
    69  				ClientTimeout:     time.Second,
    70  				SkipSSLValidation: true,
    71  			},
    72  			ExpectedErr: nil,
    73  			Server: httptest.NewServer(http.HandlerFunc(
    74  				func(w http.ResponseWriter, req *http.Request) {
    75  					token := req.Header.Get(authorizationHeader)
    76  					token = strings.TrimPrefix(token, oauthPrefix)
    77  					require.Equal(t, testToken, token)
    78  				}),
    79  			),
    80  		},
    81  		{
    82  			Name: "Success for oauth Mtls",
    83  			Config: securehttp.CallerConfig{
    84  				Credentials:       oauthMtlsCredentials,
    85  				ClientTimeout:     time.Second,
    86  				SkipSSLValidation: true,
    87  			},
    88  			ExpectedErr: nil,
    89  			Server: httptest.NewServer(http.HandlerFunc(
    90  				func(w http.ResponseWriter, req *http.Request) {
    91  					token := req.Header.Get(authorizationHeader)
    92  					token = strings.TrimPrefix(token, oauthPrefix)
    93  					require.Equal(t, testToken, token)
    94  				}),
    95  			),
    96  		},
    97  		{
    98  			Name: "Success for basic credentials",
    99  			Config: securehttp.CallerConfig{
   100  				Credentials:   basicCredentials,
   101  				ClientTimeout: time.Second,
   102  			},
   103  			ExpectedErr: nil,
   104  			Server: httptest.NewServer(http.HandlerFunc(
   105  				func(w http.ResponseWriter, req *http.Request) {
   106  					credentials := getBase64EncodedCredentials(t, req, basicPrefix)
   107  					require.Equal(t, testUser+":"+testPassword, credentials)
   108  				}),
   109  			),
   110  		},
   111  	}
   112  	for _, testCase := range testCases {
   113  		t.Run(testCase.Name, func(t *testing.T) {
   114  			caller, err := securehttp.NewCaller(testCase.Config)
   115  			require.NoError(t, err)
   116  			request, err := http.NewRequest(http.MethodGet, testCase.Server.URL, nil)
   117  			require.NoError(t, err)
   118  
   119  			_, err = caller.Call(request)
   120  			if testCase.ExpectedErr != nil {
   121  				require.Error(t, err)
   122  				require.Contains(t, err.Error(), testCase.ExpectedErr.Error())
   123  			} else {
   124  				require.NoError(t, err)
   125  			}
   126  		})
   127  	}
   128  }
   129  
   130  func getTestOauthServer(t *testing.T, assertCredentials func(t *testing.T, r *http.Request)) http.HandlerFunc {
   131  	return func(w http.ResponseWriter, req *http.Request) {
   132  		assertCredentials(t, req)
   133  
   134  		err := json.NewEncoder(w).Encode(oauth2.Token{
   135  			AccessToken: testToken,
   136  			Expiry:      time.Now().Add(time.Minute),
   137  		})
   138  		require.NoError(t, err)
   139  	}
   140  }
   141  
   142  func requireClientCredentialsFromHeader(t *testing.T, req *http.Request) {
   143  	credsStr := getBase64EncodedCredentials(t, req, basicPrefix)
   144  	require.Equal(t, testClientID+":"+testClientSecret, credsStr)
   145  }
   146  
   147  func requireClientCredentialsFromBody(t *testing.T, req *http.Request) {
   148  	requestBody, err := io.ReadAll(req.Body)
   149  	require.NoError(t, err)
   150  	require.Equal(t, testClientIDKey+"="+testClientID+"&"+testGrantTypeKey+"="+testGrantType+"&"+testScopesKey+"="+testScopes, string(requestBody))
   151  }
   152  
   153  func getBase64EncodedCredentials(t *testing.T, r *http.Request, prefix string) string {
   154  	creds := r.Header.Get(authorizationHeader)
   155  	creds = strings.TrimPrefix(creds, prefix)
   156  
   157  	credsDecoded, err := base64.StdEncoding.DecodeString(creds)
   158  	require.NoError(t, err)
   159  	return string(credsDecoded)
   160  }