github.com/verrazzano/verrazzano@v1.7.0/authproxy/src/apiserver/apiserver_test.go (about)

     1  // Copyright (c) 2023, Oracle and/or its affiliates.
     2  // Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl.
     3  
     4  package apiserver
     5  
     6  import (
     7  	"encoding/base64"
     8  	"encoding/json"
     9  	"fmt"
    10  	"net/http"
    11  	"net/http/httptest"
    12  	"strings"
    13  	"testing"
    14  
    15  	"github.com/hashicorp/go-retryablehttp"
    16  	"github.com/stretchr/testify/assert"
    17  	"github.com/verrazzano/verrazzano/authproxy/internal/testutil/testauth"
    18  	"github.com/verrazzano/verrazzano/authproxy/src/auth"
    19  	"go.uber.org/zap"
    20  )
    21  
    22  const (
    23  	apiPath          = "/api/v1/pods"
    24  	testAPIServerURL = "https://api-server.io"
    25  )
    26  
    27  // TestForwardAPIRequest tests that API requests are properly formatted and sent to the API server
    28  func TestForwardAPIRequest(t *testing.T) {
    29  	tests := []struct {
    30  		name             string
    31  		reqMethod        string
    32  		reqHeaders       map[string]string
    33  		expectedStatus   int
    34  		expectedRespHdrs map[string]string
    35  		unauthenticated  bool
    36  	}{
    37  		// GIVEN an options request
    38  		// WHEN  the request is received
    39  		// THEN  the content length header is set
    40  		{
    41  			name:           "options request",
    42  			reqMethod:      http.MethodOptions,
    43  			expectedStatus: http.StatusOK,
    44  			expectedRespHdrs: map[string]string{
    45  				"Content-Length": "0",
    46  			},
    47  		},
    48  		// GIVEN a processed request
    49  		// WHEN  the request is received
    50  		// THEN  an OK response is returned
    51  		{
    52  			name:            "processed request",
    53  			reqMethod:       http.MethodGet,
    54  			expectedStatus:  http.StatusOK,
    55  			unauthenticated: true,
    56  		},
    57  		// GIVEN a get request
    58  		// WHEN  the request is authorized
    59  		// THEN  the status returned is okay
    60  		{
    61  			name:           "get request",
    62  			reqMethod:      http.MethodGet,
    63  			expectedStatus: http.StatusOK,
    64  		},
    65  		// GIVEN a post request with headers
    66  		// WHEN  the request is forwarded
    67  		// THEN  the headers are properly added to the request
    68  		{
    69  			name:      "post request with headers",
    70  			reqMethod: http.MethodPost,
    71  			reqHeaders: map[string]string{
    72  				"test1": "header1",
    73  				"test2": "header2",
    74  			},
    75  			expectedStatus: http.StatusOK,
    76  		},
    77  	}
    78  	for _, tt := range tests {
    79  		t.Run(tt.name, func(t *testing.T) {
    80  			server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    81  				assert.Equal(t, tt.reqMethod, r.Method)
    82  				for k, v := range tt.reqHeaders {
    83  					assert.Contains(t, r.Header.Get(k), v)
    84  				}
    85  			}))
    86  			defer server.Close()
    87  
    88  			url := fmt.Sprintf("%s/clusters/local%s", testAPIServerURL, apiPath)
    89  			w := httptest.NewRecorder()
    90  			cli := retryablehttp.NewClient()
    91  			request := httptest.NewRequest(tt.reqMethod, url, strings.NewReader(""))
    92  			for k, v := range tt.reqHeaders {
    93  				request.Header.Set(k, v)
    94  			}
    95  			setEmptyToken(request)
    96  			authenticator := testauth.NewFakeAuthenticator()
    97  
    98  			apiRequest := APIRequest{
    99  				Request:       request,
   100  				RW:            w,
   101  				Client:        cli,
   102  				Authenticator: authenticator,
   103  				APIServerURL:  server.URL,
   104  				Log:           zap.S(),
   105  			}
   106  
   107  			if tt.unauthenticated {
   108  				authenticator.SetRequestFunc(testauth.AuthenticateFalse)
   109  				defer authenticator.SetRequestFunc(testauth.AuthenticateTrue)
   110  			}
   111  
   112  			apiRequest.ForwardAPIRequest()
   113  			assert.Equal(t, tt.expectedStatus, w.Code)
   114  
   115  			for k, v := range tt.expectedRespHdrs {
   116  				assert.Equal(t, v, w.Header().Get(k))
   117  			}
   118  
   119  		})
   120  	}
   121  }
   122  
   123  // TestReformatAPIRequest tests the reformatting of the request to be sent to the API server
   124  
   125  func TestReformatAPIRequest(t *testing.T) {
   126  	apiRequest := APIRequest{
   127  		APIServerURL: testAPIServerURL,
   128  		Client:       retryablehttp.NewClient(),
   129  		Log:          zap.S(),
   130  	}
   131  
   132  	tests := []struct {
   133  		name        string
   134  		url         string
   135  		expectedURL string
   136  	}{
   137  		// GIVEN a request to the Auth proxy server
   138  		// WHEN  the request is formatted correctly
   139  		// THEN  the request is properly formatted to be sent to the API server
   140  		{
   141  			name:        "test cluster path",
   142  			url:         fmt.Sprintf("https://authproxy.io/clusters/local%s", apiPath),
   143  			expectedURL: fmt.Sprintf("%s%s", apiRequest.APIServerURL, apiPath),
   144  		},
   145  		// GIVEN a request to the Auth proxy server
   146  		// WHEN  the request is malformed
   147  		// THEN  a malformed request is returned
   148  		{
   149  			name:        "test malformed request",
   150  			url:         "https://authproxy.io/malformedrequest1234",
   151  			expectedURL: fmt.Sprintf("%s/%s", apiRequest.APIServerURL, "malformedrequest1234"),
   152  		},
   153  		// GIVEN a request to the Auth proxy server
   154  		// WHEN  the request has a query param
   155  		// THEN  the query param is added to the outgoing request
   156  		{
   157  			name:        "test query param",
   158  			url:         fmt.Sprintf("https://authproxy.io/clusters/local%s?watch=1", apiPath),
   159  			expectedURL: fmt.Sprintf("%s%s?watch=1", apiRequest.APIServerURL, apiPath),
   160  		},
   161  	}
   162  	for _, tt := range tests {
   163  		t.Run(tt.name, func(t *testing.T) {
   164  			req := httptest.NewRequest(http.MethodGet, tt.url, strings.NewReader(""))
   165  			setEmptyToken(req)
   166  
   167  			formattedReq, err := apiRequest.reformatAPIRequest(req)
   168  			assert.NoError(t, err)
   169  			assert.NotNil(t, formattedReq.URL)
   170  			assert.Equal(t, tt.expectedURL, formattedReq.URL.String())
   171  		})
   172  	}
   173  }
   174  
   175  // TestSetImpersonationHeaders tests that the impersonation headers can be set for an API server request
   176  func TestSetImpersonationHeaders(t *testing.T) {
   177  	// GIVEN a request with a bad JWT token
   178  	// WHEN  the request is evaluated
   179  	// THEN  an error is returned
   180  	req := &http.Request{
   181  		Header: map[string][]string{
   182  			"Authorization": {
   183  				"bad-jwt-token",
   184  			},
   185  		},
   186  	}
   187  	err := setImpersonationHeaders(req)
   188  	assert.Error(t, err)
   189  
   190  	// GIVEN a request with a valid JWT token
   191  	// WHEN  the request is evaluated
   192  	// THEN  the request has the impersonation headers set
   193  	testUser := "test-user"
   194  	testGroups := []string{
   195  		"group1",
   196  		"group2",
   197  	}
   198  	headers := auth.ImpersonationHeaders{
   199  		User:   testUser,
   200  		Groups: testGroups,
   201  	}
   202  	tokenJSON, err := json.Marshal(headers)
   203  	assert.NoError(t, err)
   204  
   205  	tokenBase64 := base64.RawURLEncoding.EncodeToString(tokenJSON)
   206  	jwtToken := fmt.Sprintf("test.%s.test", tokenBase64)
   207  
   208  	req = &http.Request{
   209  		Header: map[string][]string{
   210  			"Authorization": {
   211  				"Bearer " + jwtToken,
   212  			},
   213  		},
   214  	}
   215  	err = setImpersonationHeaders(req)
   216  	assert.NoError(t, err)
   217  	assert.Len(t, req.Header.Values(userImpersontaionHeader), 1)
   218  	assert.Equal(t, testUser, req.Header.Get(userImpersontaionHeader))
   219  	assert.ElementsMatch(t, testGroups, req.Header.Values(groupImpersonationHeader))
   220  }
   221  
   222  // TestValidateRequest tests the request validation for the Auth Proxy
   223  func TestValidateRequest(t *testing.T) {
   224  	// GIVEN a request without the cluster path
   225  	// WHEN  the request is validated
   226  	// THEN  an error is returned
   227  	url := fmt.Sprintf("%s/%s", testAPIServerURL, apiPath)
   228  	req, err := http.NewRequest(http.MethodGet, url, strings.NewReader(""))
   229  	assert.NoError(t, err)
   230  	err = validateRequest(req)
   231  	assert.Error(t, err)
   232  
   233  	// GIVEN a request with the cluster path
   234  	// WHEN  the request is validated
   235  	// THEN  no error is returned
   236  	url = fmt.Sprintf("%s/clusters/local%s", testAPIServerURL, apiPath)
   237  	req, err = http.NewRequest(http.MethodGet, url, strings.NewReader(""))
   238  	assert.NoError(t, err)
   239  	err = validateRequest(req)
   240  	assert.NoError(t, err)
   241  }
   242  
   243  func setEmptyToken(req *http.Request) {
   244  	testToken := fmt.Sprintf("info.%s.info", base64.RawURLEncoding.EncodeToString([]byte("{}")))
   245  	req.Header.Set("Authorization", "Bearer "+testToken)
   246  }