github.com/Azure/aad-pod-identity@v1.8.17/pkg/nmi/server/server_test.go (about)

     1  package server
     2  
     3  import (
     4  	"encoding/json"
     5  	"fmt"
     6  	"net/http"
     7  	"net/http/httptest"
     8  	"net/url"
     9  	"strings"
    10  	"testing"
    11  
    12  	"github.com/gorilla/mux"
    13  )
    14  
    15  var (
    16  	rtr       *mux.Router
    17  	server    *httptest.Server
    18  	tokenPath = "/metadata/identity/oauth2/token/"
    19  )
    20  
    21  func setup() {
    22  	rtr = mux.NewRouter()
    23  	server = httptest.NewServer(rtr)
    24  }
    25  
    26  func teardown() {
    27  	server.Close()
    28  }
    29  
    30  func TestMsiHandler_NoMetadataHeader(t *testing.T) {
    31  	setup()
    32  	defer teardown()
    33  
    34  	s := &Server{
    35  		MetadataHeaderRequired: true,
    36  	}
    37  	rtr.PathPrefix("/{type:(?i:metadata)}/identity/oauth2/token/").Handler(appHandler(s.msiHandler))
    38  
    39  	req, err := http.NewRequest(http.MethodGet, tokenPath, nil)
    40  	if err != nil {
    41  		t.Fatal(err)
    42  	}
    43  
    44  	recorder := httptest.NewRecorder()
    45  	rtr.ServeHTTP(recorder, req)
    46  
    47  	if recorder.Code != http.StatusBadRequest {
    48  		t.Errorf("Unexpected status code %d", recorder.Code)
    49  	}
    50  
    51  	resp := &MetadataResponse{
    52  		Error:            "invalid_request",
    53  		ErrorDescription: "Required metadata header not specified",
    54  	}
    55  	expected, err := json.Marshal(resp)
    56  	if err != nil {
    57  		t.Fatal(err)
    58  	}
    59  
    60  	if string(expected) != strings.TrimSpace(recorder.Body.String()) {
    61  		t.Errorf("Unexpected response body %s", recorder.Body.String())
    62  	}
    63  }
    64  
    65  func TestMsiHandler_NoRemoteAddress(t *testing.T) {
    66  	setup()
    67  	defer teardown()
    68  
    69  	s := &Server{
    70  		MetadataHeaderRequired: false,
    71  	}
    72  	rtr.PathPrefix("/{type:(?i:metadata)}/identity/oauth2/token/").Handler(appHandler(s.msiHandler))
    73  
    74  	req, err := http.NewRequest(http.MethodGet, tokenPath, nil)
    75  	if err != nil {
    76  		t.Fatal(err)
    77  	}
    78  
    79  	recorder := httptest.NewRecorder()
    80  	rtr.ServeHTTP(recorder, req)
    81  
    82  	if recorder.Code != http.StatusInternalServerError {
    83  		t.Errorf("Unexpected status code %d", recorder.Code)
    84  	}
    85  
    86  	expected := "request remote address is empty"
    87  	if expected != strings.TrimSpace(recorder.Body.String()) {
    88  		t.Errorf("Unexpected response body %s", recorder.Body.String())
    89  	}
    90  }
    91  
    92  func TestParseTokenRequest(t *testing.T) {
    93  	const endpoint = "http://127.0.0.1/metadata/identity/oauth2/token"
    94  
    95  	t.Run("query present", func(t *testing.T) {
    96  		const resource = "https://vault.azure.net"
    97  		const clientID = "77788899-f67e-42e1-9a78-89985f6bff3e"
    98  		const resourceID = "/subscriptions/9f2be85c-f8ae-4569-9353-38e5e8b459ef/resourcegroups/test/providers/Microsoft.ManagedIdentity/userAssignedIdentities/test"
    99  
   100  		var r http.Request
   101  		r.URL, _ = url.Parse(fmt.Sprintf("%s?client_id=%s&msi_res_id=%s&resource=%s", endpoint, clientID, resourceID, resource))
   102  
   103  		result := parseTokenRequest(&r)
   104  
   105  		if result.ClientID != clientID {
   106  			t.Errorf("invalid ClientID - expected: %q, actual: %q", clientID, result.ClientID)
   107  		}
   108  
   109  		if result.ResourceID != resourceID {
   110  			t.Errorf("invalid ResourceID - expected: %q, actual: %q", resourceID, result.ResourceID)
   111  		}
   112  
   113  		if result.Resource != resource {
   114  			t.Errorf("invalid Resource - expected: %q, actual: %q", resource, result.Resource)
   115  		}
   116  	})
   117  
   118  	t.Run("query present with latest resource id", func(t *testing.T) {
   119  		const resource = "https://vault.azure.net"
   120  		const resourceID = "/subscriptions/9f2be85c-f8ae-4569-9353-38e5e8b459ef/resourcegroups/test/providers/Microsoft.ManagedIdentity/userAssignedIdentities/test"
   121  
   122  		var r http.Request
   123  		r.URL, _ = url.Parse(fmt.Sprintf("%s?mi_res_id=%s&resource=%s", endpoint, resourceID, resource))
   124  
   125  		result := parseTokenRequest(&r)
   126  
   127  		if result.ResourceID != resourceID {
   128  			t.Errorf("invalid ResourceID - expected: %q, actual: %q", resourceID, result.ResourceID)
   129  		}
   130  	})
   131  
   132  	t.Run("bare endpoint", func(t *testing.T) {
   133  		var r http.Request
   134  		r.URL, _ = url.Parse(endpoint)
   135  
   136  		result := parseTokenRequest(&r)
   137  
   138  		if result.ClientID != "" {
   139  			t.Errorf("invalid ClientID - expected: %q, actual: %q", "", result.ClientID)
   140  		}
   141  
   142  		if result.ResourceID != "" {
   143  			t.Errorf("invalid ResourceID - expected: %q, actual: %q", "", result.ResourceID)
   144  		}
   145  
   146  		if result.Resource != "" {
   147  			t.Errorf("invalid Resource - expected: %q, actual: %q", "", result.Resource)
   148  		}
   149  	})
   150  }
   151  
   152  func TestTokenRequest_ValidateResourceParamExists(t *testing.T) {
   153  	tr := TokenRequest{
   154  		Resource: "https://vault.azure.net",
   155  	}
   156  
   157  	if !tr.ValidateResourceParamExists() {
   158  		t.Error("ValidateResourceParamExists should have returned true when the resource is set")
   159  	}
   160  
   161  	tr.Resource = ""
   162  	if tr.ValidateResourceParamExists() {
   163  		t.Error("ValidateResourceParamExists should have returned false when the resource is unset")
   164  	}
   165  }
   166  
   167  func TestRouterPathPrefix(t *testing.T) {
   168  	tests := []struct {
   169  		name               string
   170  		url                string
   171  		expectedStatusCode int
   172  		expectedBody       string
   173  	}{
   174  		{
   175  			name:               "token request",
   176  			url:                "/metadata/identity/oauth2/token/",
   177  			expectedStatusCode: http.StatusOK,
   178  			expectedBody:       "token_request_handler",
   179  		},
   180  		{
   181  			name:               "token request without / suffix",
   182  			url:                "/metadata/identity/oauth2/token",
   183  			expectedStatusCode: http.StatusOK,
   184  			expectedBody:       "token_request_handler",
   185  		},
   186  		{
   187  			name:               "token request with upper case metadata",
   188  			url:                "/Metadata/identity/oauth2/token/",
   189  			expectedStatusCode: http.StatusOK,
   190  			expectedBody:       "token_request_handler",
   191  		},
   192  		{
   193  			name:               "token request with upper case identity",
   194  			url:                "/metadata/Identity/oauth2/token/",
   195  			expectedStatusCode: http.StatusOK,
   196  			expectedBody:       "default_handler",
   197  		},
   198  		{
   199  			name:               "host token request",
   200  			url:                "/host/token/",
   201  			expectedStatusCode: http.StatusOK,
   202  			expectedBody:       "host_token_request_handler",
   203  		},
   204  		{
   205  			name:               "host token request without / suffix",
   206  			url:                "/host/token",
   207  			expectedStatusCode: http.StatusOK,
   208  			expectedBody:       "host_token_request_handler",
   209  		},
   210  		{
   211  			name:               "instance metadata request",
   212  			url:                "/metadata/instance",
   213  			expectedStatusCode: http.StatusOK,
   214  			expectedBody:       "instance_request_handler",
   215  		},
   216  		{
   217  			name:               "instance metadata request with upper case metadata",
   218  			url:                "/Metadata/instance",
   219  			expectedStatusCode: http.StatusOK,
   220  			expectedBody:       "instance_request_handler",
   221  		},
   222  		{
   223  			name:               "instance metadata request / suffix",
   224  			url:                "/Metadata/instance/",
   225  			expectedStatusCode: http.StatusOK,
   226  			expectedBody:       "instance_request_handler",
   227  		},
   228  		{
   229  			name:               "default metadata request",
   230  			url:                "/metadata/",
   231  			expectedStatusCode: http.StatusOK,
   232  			expectedBody:       "default_handler",
   233  		},
   234  		{
   235  			name:               "invalid token request with \\oauth2",
   236  			url:                `/metadata/identity\oauth2/token/`,
   237  			expectedStatusCode: http.StatusOK,
   238  			expectedBody:       "invalid_request_handler",
   239  		},
   240  		{
   241  			name:               "invalid token request with \\token",
   242  			url:                `/metadata/identity/oauth2\token/`,
   243  			expectedStatusCode: http.StatusOK,
   244  			expectedBody:       "invalid_request_handler",
   245  		},
   246  		{
   247  			name:               "invalid token request with \\oauth2\\token",
   248  			url:                `/metadata/identity\oauth2\token/`,
   249  			expectedStatusCode: http.StatusOK,
   250  			expectedBody:       "invalid_request_handler",
   251  		},
   252  		{
   253  			name:               "invalid token request with mix of / and \\",
   254  			url:                `/metadata/identity/\oauth2\token/`,
   255  			expectedStatusCode: http.StatusOK,
   256  			expectedBody:       "invalid_request_handler",
   257  		},
   258  		{
   259  			name:               "invalid token request with multiple \\",
   260  			url:                `/metadata/identity\\\oauth2\\token/`,
   261  			expectedStatusCode: http.StatusOK,
   262  			expectedBody:       "invalid_request_handler",
   263  		},
   264  	}
   265  
   266  	for _, test := range tests {
   267  		t.Run(test.name, func(t *testing.T) {
   268  			setup()
   269  			defer teardown()
   270  
   271  			rtr.PathPrefix(tokenPathPrefix).HandlerFunc(testTokenHandler)
   272  			rtr.MatcherFunc(invalidTokenPathMatcher).HandlerFunc(testInvalidRequestHandler)
   273  			rtr.PathPrefix(hostTokenPathPrefix).HandlerFunc(testHostTokenHandler)
   274  			rtr.PathPrefix(instancePathPrefix).HandlerFunc(testInstanceHandler)
   275  			rtr.PathPrefix("/").HandlerFunc(testDefaultHandler)
   276  
   277  			req, err := http.NewRequest(http.MethodGet, test.url, nil)
   278  			if err != nil {
   279  				t.Fatal(err)
   280  			}
   281  
   282  			recorder := httptest.NewRecorder()
   283  			rtr.ServeHTTP(recorder, req)
   284  
   285  			if recorder.Code != test.expectedStatusCode {
   286  				t.Errorf("unexpected status code %d", recorder.Code)
   287  			}
   288  
   289  			if test.expectedBody != strings.TrimSpace(recorder.Body.String()) {
   290  				t.Errorf("unexpected response body %s", recorder.Body.String())
   291  			}
   292  		})
   293  	}
   294  }
   295  
   296  func testTokenHandler(w http.ResponseWriter, r *http.Request) {
   297  	fmt.Fprintf(w, "token_request_handler\n")
   298  }
   299  
   300  func testHostTokenHandler(w http.ResponseWriter, r *http.Request) {
   301  	fmt.Fprintf(w, "host_token_request_handler\n")
   302  }
   303  
   304  func testInstanceHandler(w http.ResponseWriter, r *http.Request) {
   305  	fmt.Fprintf(w, "instance_request_handler\n")
   306  }
   307  
   308  func testDefaultHandler(w http.ResponseWriter, r *http.Request) {
   309  	fmt.Fprintf(w, "default_handler\n")
   310  }
   311  
   312  func testInvalidRequestHandler(w http.ResponseWriter, r *http.Request) {
   313  	fmt.Fprintf(w, "invalid_request_handler\n")
   314  }