github.com/Azure/aad-pod-identity@v1.8.17/pkg/nmi/server/server_test.go (about) 1 package server 2 3 import ( 4 "encoding/json" 5 "fmt" 6 "net/http" 7 "net/http/httptest" 8 "net/url" 9 "strings" 10 "testing" 11 12 "github.com/gorilla/mux" 13 ) 14 15 var ( 16 rtr *mux.Router 17 server *httptest.Server 18 tokenPath = "/metadata/identity/oauth2/token/" 19 ) 20 21 func setup() { 22 rtr = mux.NewRouter() 23 server = httptest.NewServer(rtr) 24 } 25 26 func teardown() { 27 server.Close() 28 } 29 30 func TestMsiHandler_NoMetadataHeader(t *testing.T) { 31 setup() 32 defer teardown() 33 34 s := &Server{ 35 MetadataHeaderRequired: true, 36 } 37 rtr.PathPrefix("/{type:(?i:metadata)}/identity/oauth2/token/").Handler(appHandler(s.msiHandler)) 38 39 req, err := http.NewRequest(http.MethodGet, tokenPath, nil) 40 if err != nil { 41 t.Fatal(err) 42 } 43 44 recorder := httptest.NewRecorder() 45 rtr.ServeHTTP(recorder, req) 46 47 if recorder.Code != http.StatusBadRequest { 48 t.Errorf("Unexpected status code %d", recorder.Code) 49 } 50 51 resp := &MetadataResponse{ 52 Error: "invalid_request", 53 ErrorDescription: "Required metadata header not specified", 54 } 55 expected, err := json.Marshal(resp) 56 if err != nil { 57 t.Fatal(err) 58 } 59 60 if string(expected) != strings.TrimSpace(recorder.Body.String()) { 61 t.Errorf("Unexpected response body %s", recorder.Body.String()) 62 } 63 } 64 65 func TestMsiHandler_NoRemoteAddress(t *testing.T) { 66 setup() 67 defer teardown() 68 69 s := &Server{ 70 MetadataHeaderRequired: false, 71 } 72 rtr.PathPrefix("/{type:(?i:metadata)}/identity/oauth2/token/").Handler(appHandler(s.msiHandler)) 73 74 req, err := http.NewRequest(http.MethodGet, tokenPath, nil) 75 if err != nil { 76 t.Fatal(err) 77 } 78 79 recorder := httptest.NewRecorder() 80 rtr.ServeHTTP(recorder, req) 81 82 if recorder.Code != http.StatusInternalServerError { 83 t.Errorf("Unexpected status code %d", recorder.Code) 84 } 85 86 expected := "request remote address is empty" 87 if expected != strings.TrimSpace(recorder.Body.String()) { 88 t.Errorf("Unexpected response body %s", recorder.Body.String()) 89 } 90 } 91 92 func TestParseTokenRequest(t *testing.T) { 93 const endpoint = "http://127.0.0.1/metadata/identity/oauth2/token" 94 95 t.Run("query present", func(t *testing.T) { 96 const resource = "https://vault.azure.net" 97 const clientID = "77788899-f67e-42e1-9a78-89985f6bff3e" 98 const resourceID = "/subscriptions/9f2be85c-f8ae-4569-9353-38e5e8b459ef/resourcegroups/test/providers/Microsoft.ManagedIdentity/userAssignedIdentities/test" 99 100 var r http.Request 101 r.URL, _ = url.Parse(fmt.Sprintf("%s?client_id=%s&msi_res_id=%s&resource=%s", endpoint, clientID, resourceID, resource)) 102 103 result := parseTokenRequest(&r) 104 105 if result.ClientID != clientID { 106 t.Errorf("invalid ClientID - expected: %q, actual: %q", clientID, result.ClientID) 107 } 108 109 if result.ResourceID != resourceID { 110 t.Errorf("invalid ResourceID - expected: %q, actual: %q", resourceID, result.ResourceID) 111 } 112 113 if result.Resource != resource { 114 t.Errorf("invalid Resource - expected: %q, actual: %q", resource, result.Resource) 115 } 116 }) 117 118 t.Run("query present with latest resource id", func(t *testing.T) { 119 const resource = "https://vault.azure.net" 120 const resourceID = "/subscriptions/9f2be85c-f8ae-4569-9353-38e5e8b459ef/resourcegroups/test/providers/Microsoft.ManagedIdentity/userAssignedIdentities/test" 121 122 var r http.Request 123 r.URL, _ = url.Parse(fmt.Sprintf("%s?mi_res_id=%s&resource=%s", endpoint, resourceID, resource)) 124 125 result := parseTokenRequest(&r) 126 127 if result.ResourceID != resourceID { 128 t.Errorf("invalid ResourceID - expected: %q, actual: %q", resourceID, result.ResourceID) 129 } 130 }) 131 132 t.Run("bare endpoint", func(t *testing.T) { 133 var r http.Request 134 r.URL, _ = url.Parse(endpoint) 135 136 result := parseTokenRequest(&r) 137 138 if result.ClientID != "" { 139 t.Errorf("invalid ClientID - expected: %q, actual: %q", "", result.ClientID) 140 } 141 142 if result.ResourceID != "" { 143 t.Errorf("invalid ResourceID - expected: %q, actual: %q", "", result.ResourceID) 144 } 145 146 if result.Resource != "" { 147 t.Errorf("invalid Resource - expected: %q, actual: %q", "", result.Resource) 148 } 149 }) 150 } 151 152 func TestTokenRequest_ValidateResourceParamExists(t *testing.T) { 153 tr := TokenRequest{ 154 Resource: "https://vault.azure.net", 155 } 156 157 if !tr.ValidateResourceParamExists() { 158 t.Error("ValidateResourceParamExists should have returned true when the resource is set") 159 } 160 161 tr.Resource = "" 162 if tr.ValidateResourceParamExists() { 163 t.Error("ValidateResourceParamExists should have returned false when the resource is unset") 164 } 165 } 166 167 func TestRouterPathPrefix(t *testing.T) { 168 tests := []struct { 169 name string 170 url string 171 expectedStatusCode int 172 expectedBody string 173 }{ 174 { 175 name: "token request", 176 url: "/metadata/identity/oauth2/token/", 177 expectedStatusCode: http.StatusOK, 178 expectedBody: "token_request_handler", 179 }, 180 { 181 name: "token request without / suffix", 182 url: "/metadata/identity/oauth2/token", 183 expectedStatusCode: http.StatusOK, 184 expectedBody: "token_request_handler", 185 }, 186 { 187 name: "token request with upper case metadata", 188 url: "/Metadata/identity/oauth2/token/", 189 expectedStatusCode: http.StatusOK, 190 expectedBody: "token_request_handler", 191 }, 192 { 193 name: "token request with upper case identity", 194 url: "/metadata/Identity/oauth2/token/", 195 expectedStatusCode: http.StatusOK, 196 expectedBody: "default_handler", 197 }, 198 { 199 name: "host token request", 200 url: "/host/token/", 201 expectedStatusCode: http.StatusOK, 202 expectedBody: "host_token_request_handler", 203 }, 204 { 205 name: "host token request without / suffix", 206 url: "/host/token", 207 expectedStatusCode: http.StatusOK, 208 expectedBody: "host_token_request_handler", 209 }, 210 { 211 name: "instance metadata request", 212 url: "/metadata/instance", 213 expectedStatusCode: http.StatusOK, 214 expectedBody: "instance_request_handler", 215 }, 216 { 217 name: "instance metadata request with upper case metadata", 218 url: "/Metadata/instance", 219 expectedStatusCode: http.StatusOK, 220 expectedBody: "instance_request_handler", 221 }, 222 { 223 name: "instance metadata request / suffix", 224 url: "/Metadata/instance/", 225 expectedStatusCode: http.StatusOK, 226 expectedBody: "instance_request_handler", 227 }, 228 { 229 name: "default metadata request", 230 url: "/metadata/", 231 expectedStatusCode: http.StatusOK, 232 expectedBody: "default_handler", 233 }, 234 { 235 name: "invalid token request with \\oauth2", 236 url: `/metadata/identity\oauth2/token/`, 237 expectedStatusCode: http.StatusOK, 238 expectedBody: "invalid_request_handler", 239 }, 240 { 241 name: "invalid token request with \\token", 242 url: `/metadata/identity/oauth2\token/`, 243 expectedStatusCode: http.StatusOK, 244 expectedBody: "invalid_request_handler", 245 }, 246 { 247 name: "invalid token request with \\oauth2\\token", 248 url: `/metadata/identity\oauth2\token/`, 249 expectedStatusCode: http.StatusOK, 250 expectedBody: "invalid_request_handler", 251 }, 252 { 253 name: "invalid token request with mix of / and \\", 254 url: `/metadata/identity/\oauth2\token/`, 255 expectedStatusCode: http.StatusOK, 256 expectedBody: "invalid_request_handler", 257 }, 258 { 259 name: "invalid token request with multiple \\", 260 url: `/metadata/identity\\\oauth2\\token/`, 261 expectedStatusCode: http.StatusOK, 262 expectedBody: "invalid_request_handler", 263 }, 264 } 265 266 for _, test := range tests { 267 t.Run(test.name, func(t *testing.T) { 268 setup() 269 defer teardown() 270 271 rtr.PathPrefix(tokenPathPrefix).HandlerFunc(testTokenHandler) 272 rtr.MatcherFunc(invalidTokenPathMatcher).HandlerFunc(testInvalidRequestHandler) 273 rtr.PathPrefix(hostTokenPathPrefix).HandlerFunc(testHostTokenHandler) 274 rtr.PathPrefix(instancePathPrefix).HandlerFunc(testInstanceHandler) 275 rtr.PathPrefix("/").HandlerFunc(testDefaultHandler) 276 277 req, err := http.NewRequest(http.MethodGet, test.url, nil) 278 if err != nil { 279 t.Fatal(err) 280 } 281 282 recorder := httptest.NewRecorder() 283 rtr.ServeHTTP(recorder, req) 284 285 if recorder.Code != test.expectedStatusCode { 286 t.Errorf("unexpected status code %d", recorder.Code) 287 } 288 289 if test.expectedBody != strings.TrimSpace(recorder.Body.String()) { 290 t.Errorf("unexpected response body %s", recorder.Body.String()) 291 } 292 }) 293 } 294 } 295 296 func testTokenHandler(w http.ResponseWriter, r *http.Request) { 297 fmt.Fprintf(w, "token_request_handler\n") 298 } 299 300 func testHostTokenHandler(w http.ResponseWriter, r *http.Request) { 301 fmt.Fprintf(w, "host_token_request_handler\n") 302 } 303 304 func testInstanceHandler(w http.ResponseWriter, r *http.Request) { 305 fmt.Fprintf(w, "instance_request_handler\n") 306 } 307 308 func testDefaultHandler(w http.ResponseWriter, r *http.Request) { 309 fmt.Fprintf(w, "default_handler\n") 310 } 311 312 func testInvalidRequestHandler(w http.ResponseWriter, r *http.Request) { 313 fmt.Fprintf(w, "invalid_request_handler\n") 314 }