github.com/prysmaticlabs/prysm@v1.4.4/shared/gateway/api_middleware_processing_test.go (about) 1 package gateway 2 3 import ( 4 "bytes" 5 "encoding/json" 6 "net/http" 7 "net/http/httptest" 8 "strings" 9 "testing" 10 11 "github.com/prysmaticlabs/prysm/shared/grpcutils" 12 "github.com/prysmaticlabs/prysm/shared/testutil/assert" 13 "github.com/prysmaticlabs/prysm/shared/testutil/require" 14 "github.com/sirupsen/logrus/hooks/test" 15 ) 16 17 type testRequestContainer struct { 18 TestString string 19 TestHexString string `hex:"true"` 20 } 21 22 func defaultRequestContainer() *testRequestContainer { 23 return &testRequestContainer{ 24 TestString: "test string", 25 TestHexString: "0x666F6F", // hex encoding of "foo" 26 } 27 } 28 29 type testResponseContainer struct { 30 TestString string 31 TestHex string `hex:"true"` 32 TestEnum string `enum:"true"` 33 TestTime string `time:"true"` 34 } 35 36 func defaultResponseContainer() *testResponseContainer { 37 return &testResponseContainer{ 38 TestString: "test string", 39 TestHex: "Zm9v", // base64 encoding of "foo" 40 TestEnum: "Test Enum", 41 TestTime: "2006-01-02T15:04:05Z", 42 } 43 } 44 45 type testErrorJson struct { 46 Message string 47 Code int 48 CustomField string 49 } 50 51 // StatusCode returns the error's underlying error code. 52 func (e *testErrorJson) StatusCode() int { 53 return e.Code 54 } 55 56 // Msg returns the error's underlying message. 57 func (e *testErrorJson) Msg() string { 58 return e.Message 59 } 60 61 // SetCode sets the error's underlying error code. 62 func (e *testErrorJson) SetCode(code int) { 63 e.Code = code 64 } 65 66 func TestDeserializeRequestBodyIntoContainer(t *testing.T) { 67 t.Run("ok", func(t *testing.T) { 68 var bodyJson bytes.Buffer 69 err := json.NewEncoder(&bodyJson).Encode(defaultRequestContainer()) 70 require.NoError(t, err) 71 72 container := &testRequestContainer{} 73 errJson := DeserializeRequestBodyIntoContainer(&bodyJson, container) 74 require.Equal(t, true, errJson == nil) 75 assert.Equal(t, "test string", container.TestString) 76 }) 77 78 t.Run("error", func(t *testing.T) { 79 var bodyJson bytes.Buffer 80 bodyJson.Write([]byte("foo")) 81 errJson := DeserializeRequestBodyIntoContainer(&bodyJson, &testRequestContainer{}) 82 require.NotNil(t, errJson) 83 assert.Equal(t, true, strings.Contains(errJson.Msg(), "could not decode request body")) 84 assert.Equal(t, http.StatusInternalServerError, errJson.StatusCode()) 85 }) 86 } 87 88 func TestProcessRequestContainerFields(t *testing.T) { 89 t.Run("ok", func(t *testing.T) { 90 container := defaultRequestContainer() 91 92 errJson := ProcessRequestContainerFields(container) 93 require.Equal(t, true, errJson == nil) 94 assert.Equal(t, "Zm9v", container.TestHexString) 95 }) 96 97 t.Run("error", func(t *testing.T) { 98 errJson := ProcessRequestContainerFields("foo") 99 require.NotNil(t, errJson) 100 assert.Equal(t, true, strings.Contains(errJson.Msg(), "could not process request data")) 101 assert.Equal(t, http.StatusInternalServerError, errJson.StatusCode()) 102 }) 103 } 104 105 func TestSetRequestBodyToRequestContainer(t *testing.T) { 106 var body bytes.Buffer 107 request := httptest.NewRequest("GET", "http://foo.example", &body) 108 109 errJson := SetRequestBodyToRequestContainer(defaultRequestContainer(), request) 110 require.Equal(t, true, errJson == nil) 111 container := &testRequestContainer{} 112 require.NoError(t, json.NewDecoder(request.Body).Decode(container)) 113 assert.Equal(t, "test string", container.TestString) 114 contentLengthHeader, ok := request.Header["Content-Length"] 115 require.Equal(t, true, ok) 116 require.Equal(t, 1, len(contentLengthHeader), "wrong number of header values") 117 assert.Equal(t, "55", contentLengthHeader[0]) 118 assert.Equal(t, int64(55), request.ContentLength) 119 } 120 121 func TestPrepareRequestForProxying(t *testing.T) { 122 middleware := &ApiProxyMiddleware{ 123 GatewayAddress: "http://gateway.example", 124 } 125 // We will set some params to make the request more interesting. 126 endpoint := Endpoint{ 127 Path: "/{url_param}", 128 GetRequestURLLiterals: []string{"url_param"}, 129 GetRequestQueryParams: []QueryParam{{Name: "query_param"}}, 130 } 131 var body bytes.Buffer 132 request := httptest.NewRequest("GET", "http://foo.example?query_param=bar", &body) 133 134 errJson := middleware.PrepareRequestForProxying(endpoint, request) 135 require.Equal(t, true, errJson == nil) 136 assert.Equal(t, "http", request.URL.Scheme) 137 assert.Equal(t, middleware.GatewayAddress, request.URL.Host) 138 assert.Equal(t, "", request.RequestURI) 139 } 140 141 func TestReadGrpcResponseBody(t *testing.T) { 142 var b bytes.Buffer 143 b.Write([]byte("foo")) 144 145 body, jsonErr := ReadGrpcResponseBody(&b) 146 require.Equal(t, true, jsonErr == nil) 147 assert.Equal(t, "foo", string(body)) 148 } 149 150 func TestDeserializeGrpcResponseBodyIntoErrorJson(t *testing.T) { 151 t.Run("ok", func(t *testing.T) { 152 e := &testErrorJson{ 153 Message: "foo", 154 Code: 500, 155 } 156 body, err := json.Marshal(e) 157 require.NoError(t, err) 158 159 eToDeserialize := &testErrorJson{} 160 errJson := DeserializeGrpcResponseBodyIntoErrorJson(eToDeserialize, body) 161 require.Equal(t, true, errJson == nil) 162 assert.Equal(t, "foo", eToDeserialize.Msg()) 163 assert.Equal(t, 500, eToDeserialize.StatusCode()) 164 }) 165 166 t.Run("error", func(t *testing.T) { 167 errJson := DeserializeGrpcResponseBodyIntoErrorJson(nil, nil) 168 require.NotNil(t, errJson) 169 assert.Equal(t, true, strings.Contains(errJson.Msg(), "could not unmarshal error")) 170 }) 171 } 172 173 func TestHandleGrpcResponseError(t *testing.T) { 174 response := &http.Response{ 175 StatusCode: 400, 176 Header: http.Header{ 177 "Foo": []string{"foo"}, 178 "Bar": []string{"bar"}, 179 }, 180 } 181 writer := httptest.NewRecorder() 182 errJson := &testErrorJson{ 183 Message: "foo", 184 Code: 500, 185 } 186 187 HandleGrpcResponseError(errJson, response, writer) 188 v, ok := writer.Header()["Foo"] 189 require.Equal(t, true, ok, "header not found") 190 require.Equal(t, 1, len(v), "wrong number of header values") 191 assert.Equal(t, "foo", v[0]) 192 v, ok = writer.Header()["Bar"] 193 require.Equal(t, true, ok, "header not found") 194 require.Equal(t, 1, len(v), "wrong number of header values") 195 assert.Equal(t, "bar", v[0]) 196 assert.Equal(t, 400, errJson.StatusCode()) 197 } 198 199 func TestGrpcResponseIsStatusCodeOnly(t *testing.T) { 200 var body bytes.Buffer 201 202 t.Run("status_code_only", func(t *testing.T) { 203 request := httptest.NewRequest("GET", "http://foo.example", &body) 204 result := GrpcResponseIsStatusCodeOnly(request, nil) 205 assert.Equal(t, true, result) 206 }) 207 208 t.Run("different_method", func(t *testing.T) { 209 request := httptest.NewRequest("POST", "http://foo.example", &body) 210 result := GrpcResponseIsStatusCodeOnly(request, nil) 211 assert.Equal(t, false, result) 212 }) 213 214 t.Run("non_empty_response", func(t *testing.T) { 215 request := httptest.NewRequest("GET", "http://foo.example", &body) 216 result := GrpcResponseIsStatusCodeOnly(request, &testRequestContainer{}) 217 assert.Equal(t, false, result) 218 }) 219 } 220 221 func TestDeserializeGrpcResponseBodyIntoContainer(t *testing.T) { 222 t.Run("ok", func(t *testing.T) { 223 body, err := json.Marshal(defaultRequestContainer()) 224 require.NoError(t, err) 225 226 container := &testRequestContainer{} 227 errJson := DeserializeGrpcResponseBodyIntoContainer(body, container) 228 require.Equal(t, true, errJson == nil) 229 assert.Equal(t, "test string", container.TestString) 230 }) 231 232 t.Run("error", func(t *testing.T) { 233 var bodyJson bytes.Buffer 234 bodyJson.Write([]byte("foo")) 235 errJson := DeserializeGrpcResponseBodyIntoContainer(bodyJson.Bytes(), &testRequestContainer{}) 236 require.NotNil(t, errJson) 237 assert.Equal(t, true, strings.Contains(errJson.Msg(), "could not unmarshal response")) 238 assert.Equal(t, http.StatusInternalServerError, errJson.StatusCode()) 239 }) 240 } 241 242 func TestProcessMiddlewareResponseFields(t *testing.T) { 243 t.Run("Ok", func(t *testing.T) { 244 container := defaultResponseContainer() 245 246 errJson := ProcessMiddlewareResponseFields(container) 247 require.Equal(t, true, errJson == nil) 248 assert.Equal(t, "0x666f6f", container.TestHex) 249 assert.Equal(t, "test enum", container.TestEnum) 250 assert.Equal(t, "1136214245", container.TestTime) 251 }) 252 253 t.Run("error", func(t *testing.T) { 254 errJson := ProcessMiddlewareResponseFields("foo") 255 require.NotNil(t, errJson) 256 assert.Equal(t, true, strings.Contains(errJson.Msg(), "could not process response data")) 257 assert.Equal(t, http.StatusInternalServerError, errJson.StatusCode()) 258 }) 259 } 260 261 func TestSerializeMiddlewareResponseIntoJson(t *testing.T) { 262 container := defaultResponseContainer() 263 j, errJson := SerializeMiddlewareResponseIntoJson(container) 264 assert.Equal(t, true, errJson == nil) 265 cToDeserialize := &testResponseContainer{} 266 require.NoError(t, json.Unmarshal(j, cToDeserialize)) 267 assert.Equal(t, "test string", cToDeserialize.TestString) 268 } 269 270 func TestWriteMiddlewareResponseHeadersAndBody(t *testing.T) { 271 var body bytes.Buffer 272 273 t.Run("GET", func(t *testing.T) { 274 request := httptest.NewRequest("GET", "http://foo.example", &body) 275 response := &http.Response{ 276 Header: http.Header{ 277 "Foo": []string{"foo"}, 278 "Grpc-Metadata-" + grpcutils.HttpCodeMetadataKey: []string{"204"}, 279 }, 280 } 281 container := defaultResponseContainer() 282 responseJson, err := json.Marshal(container) 283 require.NoError(t, err) 284 writer := httptest.NewRecorder() 285 writer.Body = &bytes.Buffer{} 286 287 errJson := WriteMiddlewareResponseHeadersAndBody(request, response, responseJson, writer) 288 require.Equal(t, true, errJson == nil) 289 v, ok := writer.Header()["Foo"] 290 require.Equal(t, true, ok, "header not found") 291 require.Equal(t, 1, len(v), "wrong number of header values") 292 assert.Equal(t, "foo", v[0]) 293 v, ok = writer.Header()["Content-Length"] 294 require.Equal(t, true, ok, "header not found") 295 require.Equal(t, 1, len(v), "wrong number of header values") 296 assert.Equal(t, "102", v[0]) 297 assert.Equal(t, 204, writer.Code) 298 assert.DeepEqual(t, responseJson, writer.Body.Bytes()) 299 }) 300 301 t.Run("GET_no_grpc_status_code_header", func(t *testing.T) { 302 request := httptest.NewRequest("GET", "http://foo.example", &body) 303 response := &http.Response{ 304 Header: http.Header{}, 305 StatusCode: 204, 306 } 307 container := defaultResponseContainer() 308 responseJson, err := json.Marshal(container) 309 require.NoError(t, err) 310 writer := httptest.NewRecorder() 311 312 errJson := WriteMiddlewareResponseHeadersAndBody(request, response, responseJson, writer) 313 require.Equal(t, true, errJson == nil) 314 assert.Equal(t, 204, writer.Code) 315 }) 316 317 t.Run("GET_invalid_status_code", func(t *testing.T) { 318 request := httptest.NewRequest("GET", "http://foo.example", &body) 319 response := &http.Response{ 320 Header: http.Header{}, 321 } 322 323 // Set invalid status code. 324 response.Header["Grpc-Metadata-"+grpcutils.HttpCodeMetadataKey] = []string{"invalid"} 325 326 container := defaultResponseContainer() 327 responseJson, err := json.Marshal(container) 328 require.NoError(t, err) 329 writer := httptest.NewRecorder() 330 331 errJson := WriteMiddlewareResponseHeadersAndBody(request, response, responseJson, writer) 332 require.Equal(t, false, errJson == nil) 333 assert.Equal(t, true, strings.Contains(errJson.Msg(), "could not parse status code")) 334 assert.Equal(t, http.StatusInternalServerError, errJson.StatusCode()) 335 }) 336 337 t.Run("POST", func(t *testing.T) { 338 request := httptest.NewRequest("POST", "http://foo.example", &body) 339 response := &http.Response{ 340 Header: http.Header{}, 341 StatusCode: 204, 342 } 343 container := defaultResponseContainer() 344 responseJson, err := json.Marshal(container) 345 require.NoError(t, err) 346 writer := httptest.NewRecorder() 347 348 errJson := WriteMiddlewareResponseHeadersAndBody(request, response, responseJson, writer) 349 require.Equal(t, true, errJson == nil) 350 assert.Equal(t, 204, writer.Code) 351 }) 352 } 353 354 func TestWriteError(t *testing.T) { 355 t.Run("ok", func(t *testing.T) { 356 responseHeader := http.Header{ 357 "Grpc-Metadata-" + grpcutils.CustomErrorMetadataKey: []string{"{\"CustomField\":\"bar\"}"}, 358 } 359 errJson := &testErrorJson{ 360 Message: "foo", 361 Code: 500, 362 } 363 writer := httptest.NewRecorder() 364 writer.Body = &bytes.Buffer{} 365 366 WriteError(writer, errJson, responseHeader) 367 v, ok := writer.Header()["Content-Length"] 368 require.Equal(t, true, ok, "header not found") 369 require.Equal(t, 1, len(v), "wrong number of header values") 370 assert.Equal(t, "48", v[0]) 371 v, ok = writer.Header()["Content-Type"] 372 require.Equal(t, true, ok, "header not found") 373 require.Equal(t, 1, len(v), "wrong number of header values") 374 assert.Equal(t, "application/json", v[0]) 375 assert.Equal(t, 500, writer.Code) 376 eDeserialize := &testErrorJson{} 377 require.NoError(t, json.Unmarshal(writer.Body.Bytes(), eDeserialize)) 378 assert.Equal(t, "foo", eDeserialize.Message) 379 assert.Equal(t, 500, eDeserialize.Code) 380 assert.Equal(t, "bar", eDeserialize.CustomField) 381 }) 382 383 t.Run("invalid_custom_error_header", func(t *testing.T) { 384 logHook := test.NewGlobal() 385 386 responseHeader := http.Header{ 387 "Grpc-Metadata-" + grpcutils.CustomErrorMetadataKey: []string{"invalid"}, 388 } 389 390 WriteError(httptest.NewRecorder(), &testErrorJson{}, responseHeader) 391 assert.LogsContain(t, logHook, "Could not unmarshal custom error message") 392 }) 393 }