github.com/prysmaticlabs/prysm@v1.4.4/shared/gateway/api_middleware_processing_test.go

     1  package gateway
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/json"
     6  	"net/http"
     7  	"net/http/httptest"
     8  	"strings"
     9  	"testing"
    10  
    11  	"github.com/prysmaticlabs/prysm/shared/grpcutils"
    12  	"github.com/prysmaticlabs/prysm/shared/testutil/assert"
    13  	"github.com/prysmaticlabs/prysm/shared/testutil/require"
    14  	"github.com/sirupsen/logrus/hooks/test"
    15  )
    16  
    17  type testRequestContainer struct {
    18  	TestString    string
    19  	TestHexString string `hex:"true"`
    20  }
    21  
    22  func defaultRequestContainer() *testRequestContainer {
    23  	return &testRequestContainer{
    24  		TestString:    "test string",
    25  		TestHexString: "0x666F6F", // hex encoding of "foo"
    26  	}
    27  }
    28  
    29  type testResponseContainer struct {
    30  	TestString string
    31  	TestHex    string `hex:"true"`
    32  	TestEnum   string `enum:"true"`
    33  	TestTime   string `time:"true"`
    34  }
    35  
    36  func defaultResponseContainer() *testResponseContainer {
    37  	return &testResponseContainer{
    38  		TestString: "test string",
    39  		TestHex:    "Zm9v", // base64 encoding of "foo"
    40  		TestEnum:   "Test Enum",
    41  		TestTime:   "2006-01-02T15:04:05Z",
    42  	}
    43  }
    44  
    45  type testErrorJson struct {
    46  	Message     string
    47  	Code        int
    48  	CustomField string
    49  }
    50  
    51  // StatusCode returns the error's underlying error code.
    52  func (e *testErrorJson) StatusCode() int {
    53  	return e.Code
    54  }
    55  
    56  // Msg returns the error's underlying message.
    57  func (e *testErrorJson) Msg() string {
    58  	return e.Message
    59  }
    60  
    61  // SetCode sets the error's underlying error code.
    62  func (e *testErrorJson) SetCode(code int) {
    63  	e.Code = code
    64  }
    65  
    66  func TestDeserializeRequestBodyIntoContainer(t *testing.T) {
    67  	t.Run("ok", func(t *testing.T) {
    68  		var bodyJson bytes.Buffer
    69  		err := json.NewEncoder(&bodyJson).Encode(defaultRequestContainer())
    70  		require.NoError(t, err)
    71  
    72  		container := &testRequestContainer{}
    73  		errJson := DeserializeRequestBodyIntoContainer(&bodyJson, container)
    74  		require.Equal(t, true, errJson == nil)
    75  		assert.Equal(t, "test string", container.TestString)
    76  	})
    77  
    78  	t.Run("error", func(t *testing.T) {
    79  		var bodyJson bytes.Buffer
    80  		bodyJson.Write([]byte("foo"))
    81  		errJson := DeserializeRequestBodyIntoContainer(&bodyJson, &testRequestContainer{})
    82  		require.NotNil(t, errJson)
    83  		assert.Equal(t, true, strings.Contains(errJson.Msg(), "could not decode request body"))
    84  		assert.Equal(t, http.StatusInternalServerError, errJson.StatusCode())
    85  	})
    86  }
    87  
    88  func TestProcessRequestContainerFields(t *testing.T) {
    89  	t.Run("ok", func(t *testing.T) {
    90  		container := defaultRequestContainer()
    91  
    92  		errJson := ProcessRequestContainerFields(container)
    93  		require.Equal(t, true, errJson == nil)
    94  		assert.Equal(t, "Zm9v", container.TestHexString)
    95  	})
    96  
    97  	t.Run("error", func(t *testing.T) {
    98  		errJson := ProcessRequestContainerFields("foo")
    99  		require.NotNil(t, errJson)
   100  		assert.Equal(t, true, strings.Contains(errJson.Msg(), "could not process request data"))
   101  		assert.Equal(t, http.StatusInternalServerError, errJson.StatusCode())
   102  	})
   103  }
   104  
   105  func TestSetRequestBodyToRequestContainer(t *testing.T) {
   106  	var body bytes.Buffer
   107  	request := httptest.NewRequest("GET", "http://foo.example", &body)
   108  
   109  	errJson := SetRequestBodyToRequestContainer(defaultRequestContainer(), request)
   110  	require.Equal(t, true, errJson == nil)
   111  	container := &testRequestContainer{}
   112  	require.NoError(t, json.NewDecoder(request.Body).Decode(container))
   113  	assert.Equal(t, "test string", container.TestString)
   114  	contentLengthHeader, ok := request.Header["Content-Length"]
   115  	require.Equal(t, true, ok)
   116  	require.Equal(t, 1, len(contentLengthHeader), "wrong number of header values")
   117  	assert.Equal(t, "55", contentLengthHeader[0])
   118  	assert.Equal(t, int64(55), request.ContentLength)
   119  }
   120  
   121  func TestPrepareRequestForProxying(t *testing.T) {
   122  	middleware := &ApiProxyMiddleware{
   123  		GatewayAddress: "http://gateway.example",
   124  	}
   125  	// We will set some params to make the request more interesting.
   126  	endpoint := Endpoint{
   127  		Path:                  "/{url_param}",
   128  		GetRequestURLLiterals: []string{"url_param"},
   129  		GetRequestQueryParams: []QueryParam{{Name: "query_param"}},
   130  	}
   131  	var body bytes.Buffer
   132  	request := httptest.NewRequest("GET", "http://foo.example?query_param=bar", &body)
   133  
   134  	errJson := middleware.PrepareRequestForProxying(endpoint, request)
   135  	require.Equal(t, true, errJson == nil)
   136  	assert.Equal(t, "http", request.URL.Scheme)
   137  	assert.Equal(t, middleware.GatewayAddress, request.URL.Host)
   138  	assert.Equal(t, "", request.RequestURI)
   139  }
   140  
   141  func TestReadGrpcResponseBody(t *testing.T) {
   142  	var b bytes.Buffer
   143  	b.Write([]byte("foo"))
   144  
   145  	body, jsonErr := ReadGrpcResponseBody(&b)
   146  	require.Equal(t, true, jsonErr == nil)
   147  	assert.Equal(t, "foo", string(body))
   148  }
   149  
   150  func TestDeserializeGrpcResponseBodyIntoErrorJson(t *testing.T) {
   151  	t.Run("ok", func(t *testing.T) {
   152  		e := &testErrorJson{
   153  			Message: "foo",
   154  			Code:    500,
   155  		}
   156  		body, err := json.Marshal(e)
   157  		require.NoError(t, err)
   158  
   159  		eToDeserialize := &testErrorJson{}
   160  		errJson := DeserializeGrpcResponseBodyIntoErrorJson(eToDeserialize, body)
   161  		require.Equal(t, true, errJson == nil)
   162  		assert.Equal(t, "foo", eToDeserialize.Msg())
   163  		assert.Equal(t, 500, eToDeserialize.StatusCode())
   164  	})
   165  
   166  	t.Run("error", func(t *testing.T) {
   167  		errJson := DeserializeGrpcResponseBodyIntoErrorJson(nil, nil)
   168  		require.NotNil(t, errJson)
   169  		assert.Equal(t, true, strings.Contains(errJson.Msg(), "could not unmarshal error"))
   170  	})
   171  }
   172  
   173  func TestHandleGrpcResponseError(t *testing.T) {
   174  	response := &http.Response{
   175  		StatusCode: 400,
   176  		Header: http.Header{
   177  			"Foo": []string{"foo"},
   178  			"Bar": []string{"bar"},
   179  		},
   180  	}
   181  	writer := httptest.NewRecorder()
   182  	errJson := &testErrorJson{
   183  		Message: "foo",
   184  		Code:    500,
   185  	}
   186  
   187  	HandleGrpcResponseError(errJson, response, writer)
   188  	v, ok := writer.Header()["Foo"]
   189  	require.Equal(t, true, ok, "header not found")
   190  	require.Equal(t, 1, len(v), "wrong number of header values")
   191  	assert.Equal(t, "foo", v[0])
   192  	v, ok = writer.Header()["Bar"]
   193  	require.Equal(t, true, ok, "header not found")
   194  	require.Equal(t, 1, len(v), "wrong number of header values")
   195  	assert.Equal(t, "bar", v[0])
   196  	assert.Equal(t, 400, errJson.StatusCode())
   197  }
   198  
   199  func TestGrpcResponseIsStatusCodeOnly(t *testing.T) {
   200  	var body bytes.Buffer
   201  
   202  	t.Run("status_code_only", func(t *testing.T) {
   203  		request := httptest.NewRequest("GET", "http://foo.example", &body)
   204  		result := GrpcResponseIsStatusCodeOnly(request, nil)
   205  		assert.Equal(t, true, result)
   206  	})
   207  
   208  	t.Run("different_method", func(t *testing.T) {
   209  		request := httptest.NewRequest("POST", "http://foo.example", &body)
   210  		result := GrpcResponseIsStatusCodeOnly(request, nil)
   211  		assert.Equal(t, false, result)
   212  	})
   213  
   214  	t.Run("non_empty_response", func(t *testing.T) {
   215  		request := httptest.NewRequest("GET", "http://foo.example", &body)
   216  		result := GrpcResponseIsStatusCodeOnly(request, &testRequestContainer{})
   217  		assert.Equal(t, false, result)
   218  	})
   219  }
   220  
   221  func TestDeserializeGrpcResponseBodyIntoContainer(t *testing.T) {
   222  	t.Run("ok", func(t *testing.T) {
   223  		body, err := json.Marshal(defaultRequestContainer())
   224  		require.NoError(t, err)
   225  
   226  		container := &testRequestContainer{}
   227  		errJson := DeserializeGrpcResponseBodyIntoContainer(body, container)
   228  		require.Equal(t, true, errJson == nil)
   229  		assert.Equal(t, "test string", container.TestString)
   230  	})
   231  
   232  	t.Run("error", func(t *testing.T) {
   233  		var bodyJson bytes.Buffer
   234  		bodyJson.Write([]byte("foo"))
   235  		errJson := DeserializeGrpcResponseBodyIntoContainer(bodyJson.Bytes(), &testRequestContainer{})
   236  		require.NotNil(t, errJson)
   237  		assert.Equal(t, true, strings.Contains(errJson.Msg(), "could not unmarshal response"))
   238  		assert.Equal(t, http.StatusInternalServerError, errJson.StatusCode())
   239  	})
   240  }
   241  
   242  func TestProcessMiddlewareResponseFields(t *testing.T) {
   243  	t.Run("Ok", func(t *testing.T) {
   244  		container := defaultResponseContainer()
   245  
   246  		errJson := ProcessMiddlewareResponseFields(container)
   247  		require.Equal(t, true, errJson == nil)
   248  		assert.Equal(t, "0x666f6f", container.TestHex)
   249  		assert.Equal(t, "test enum", container.TestEnum)
   250  		assert.Equal(t, "1136214245", container.TestTime)
   251  	})
   252  
   253  	t.Run("error", func(t *testing.T) {
   254  		errJson := ProcessMiddlewareResponseFields("foo")
   255  		require.NotNil(t, errJson)
   256  		assert.Equal(t, true, strings.Contains(errJson.Msg(), "could not process response data"))
   257  		assert.Equal(t, http.StatusInternalServerError, errJson.StatusCode())
   258  	})
   259  }
   260  
   261  func TestSerializeMiddlewareResponseIntoJson(t *testing.T) {
   262  	container := defaultResponseContainer()
   263  	j, errJson := SerializeMiddlewareResponseIntoJson(container)
   264  	assert.Equal(t, true, errJson == nil)
   265  	cToDeserialize := &testResponseContainer{}
   266  	require.NoError(t, json.Unmarshal(j, cToDeserialize))
   267  	assert.Equal(t, "test string", cToDeserialize.TestString)
   268  }
   269  
   270  func TestWriteMiddlewareResponseHeadersAndBody(t *testing.T) {
   271  	var body bytes.Buffer
   272  
   273  	t.Run("GET", func(t *testing.T) {
   274  		request := httptest.NewRequest("GET", "http://foo.example", &body)
   275  		response := &http.Response{
   276  			Header: http.Header{
   277  				"Foo": []string{"foo"},
   278  				"Grpc-Metadata-" + grpcutils.HttpCodeMetadataKey: []string{"204"},
   279  			},
   280  		}
   281  		container := defaultResponseContainer()
   282  		responseJson, err := json.Marshal(container)
   283  		require.NoError(t, err)
   284  		writer := httptest.NewRecorder()
   285  		writer.Body = &bytes.Buffer{}
   286  
   287  		errJson := WriteMiddlewareResponseHeadersAndBody(request, response, responseJson, writer)
   288  		require.Equal(t, true, errJson == nil)
   289  		v, ok := writer.Header()["Foo"]
   290  		require.Equal(t, true, ok, "header not found")
   291  		require.Equal(t, 1, len(v), "wrong number of header values")
   292  		assert.Equal(t, "foo", v[0])
   293  		v, ok = writer.Header()["Content-Length"]
   294  		require.Equal(t, true, ok, "header not found")
   295  		require.Equal(t, 1, len(v), "wrong number of header values")
   296  		assert.Equal(t, "102", v[0])
   297  		assert.Equal(t, 204, writer.Code)
   298  		assert.DeepEqual(t, responseJson, writer.Body.Bytes())
   299  	})
   300  
   301  	t.Run("GET_no_grpc_status_code_header", func(t *testing.T) {
   302  		request := httptest.NewRequest("GET", "http://foo.example", &body)
   303  		response := &http.Response{
   304  			Header:     http.Header{},
   305  			StatusCode: 204,
   306  		}
   307  		container := defaultResponseContainer()
   308  		responseJson, err := json.Marshal(container)
   309  		require.NoError(t, err)
   310  		writer := httptest.NewRecorder()
   311  
   312  		errJson := WriteMiddlewareResponseHeadersAndBody(request, response, responseJson, writer)
   313  		require.Equal(t, true, errJson == nil)
   314  		assert.Equal(t, 204, writer.Code)
   315  	})
   316  
   317  	t.Run("GET_invalid_status_code", func(t *testing.T) {
   318  		request := httptest.NewRequest("GET", "http://foo.example", &body)
   319  		response := &http.Response{
   320  			Header: http.Header{},
   321  		}
   322  
   323  		// Set invalid status code.
   324  		response.Header["Grpc-Metadata-"+grpcutils.HttpCodeMetadataKey] = []string{"invalid"}
   325  
   326  		container := defaultResponseContainer()
   327  		responseJson, err := json.Marshal(container)
   328  		require.NoError(t, err)
   329  		writer := httptest.NewRecorder()
   330  
   331  		errJson := WriteMiddlewareResponseHeadersAndBody(request, response, responseJson, writer)
   332  		require.Equal(t, false, errJson == nil)
   333  		assert.Equal(t, true, strings.Contains(errJson.Msg(), "could not parse status code"))
   334  		assert.Equal(t, http.StatusInternalServerError, errJson.StatusCode())
   335  	})
   336  
   337  	t.Run("POST", func(t *testing.T) {
   338  		request := httptest.NewRequest("POST", "http://foo.example", &body)
   339  		response := &http.Response{
   340  			Header:     http.Header{},
   341  			StatusCode: 204,
   342  		}
   343  		container := defaultResponseContainer()
   344  		responseJson, err := json.Marshal(container)
   345  		require.NoError(t, err)
   346  		writer := httptest.NewRecorder()
   347  
   348  		errJson := WriteMiddlewareResponseHeadersAndBody(request, response, responseJson, writer)
   349  		require.Equal(t, true, errJson == nil)
   350  		assert.Equal(t, 204, writer.Code)
   351  	})
   352  }
   353  
   354  func TestWriteError(t *testing.T) {
   355  	t.Run("ok", func(t *testing.T) {
   356  		responseHeader := http.Header{
   357  			"Grpc-Metadata-" + grpcutils.CustomErrorMetadataKey: []string{"{\"CustomField\":\"bar\"}"},
   358  		}
   359  		errJson := &testErrorJson{
   360  			Message: "foo",
   361  			Code:    500,
   362  		}
   363  		writer := httptest.NewRecorder()
   364  		writer.Body = &bytes.Buffer{}
   365  
   366  		WriteError(writer, errJson, responseHeader)
   367  		v, ok := writer.Header()["Content-Length"]
   368  		require.Equal(t, true, ok, "header not found")
   369  		require.Equal(t, 1, len(v), "wrong number of header values")
   370  		assert.Equal(t, "48", v[0])
   371  		v, ok = writer.Header()["Content-Type"]
   372  		require.Equal(t, true, ok, "header not found")
   373  		require.Equal(t, 1, len(v), "wrong number of header values")
   374  		assert.Equal(t, "application/json", v[0])
   375  		assert.Equal(t, 500, writer.Code)
   376  		eDeserialize := &testErrorJson{}
   377  		require.NoError(t, json.Unmarshal(writer.Body.Bytes(), eDeserialize))
   378  		assert.Equal(t, "foo", eDeserialize.Message)
   379  		assert.Equal(t, 500, eDeserialize.Code)
   380  		assert.Equal(t, "bar", eDeserialize.CustomField)
   381  	})
   382  
   383  	t.Run("invalid_custom_error_header", func(t *testing.T) {
   384  		logHook := test.NewGlobal()
   385  
   386  		responseHeader := http.Header{
   387  			"Grpc-Metadata-" + grpcutils.CustomErrorMetadataKey: []string{"invalid"},
   388  		}
   389  
   390  		WriteError(httptest.NewRecorder(), &testErrorJson{}, responseHeader)
   391  		assert.LogsContain(t, logHook, "Could not unmarshal custom error message")
   392  	})
   393  }