github.com/prysmaticlabs/prysm@v1.4.4/beacon-chain/rpc/apimiddleware/custom_handlers_test.go (about)

     1  package apimiddleware
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/json"
     7  	"net/http"
     8  	"net/http/httptest"
     9  	"strings"
    10  	"testing"
    11  	"time"
    12  
    13  	"github.com/prysmaticlabs/prysm/beacon-chain/rpc/eth/v1/events"
    14  	"github.com/prysmaticlabs/prysm/shared/gateway"
    15  	"github.com/prysmaticlabs/prysm/shared/grpcutils"
    16  	"github.com/prysmaticlabs/prysm/shared/testutil/assert"
    17  	"github.com/prysmaticlabs/prysm/shared/testutil/require"
    18  	"github.com/r3labs/sse"
    19  )
    20  
    21  func TestSSZRequested(t *testing.T) {
    22  	t.Run("ssz_requested", func(t *testing.T) {
    23  		request := httptest.NewRequest("GET", "http://foo.example", nil)
    24  		request.Header["Accept"] = []string{"application/octet-stream"}
    25  		result := sszRequested(request)
    26  		assert.Equal(t, true, result)
    27  	})
    28  
    29  	t.Run("multiple_content_types", func(t *testing.T) {
    30  		request := httptest.NewRequest("GET", "http://foo.example", nil)
    31  		request.Header["Accept"] = []string{"application/json", "application/octet-stream"}
    32  		result := sszRequested(request)
    33  		assert.Equal(t, true, result)
    34  	})
    35  
    36  	t.Run("no_header", func(t *testing.T) {
    37  		request := httptest.NewRequest("GET", "http://foo.example", nil)
    38  		result := sszRequested(request)
    39  		assert.Equal(t, false, result)
    40  	})
    41  
    42  	t.Run("other_content_type", func(t *testing.T) {
    43  		request := httptest.NewRequest("GET", "http://foo.example", nil)
    44  		request.Header["Accept"] = []string{"application/json"}
    45  		result := sszRequested(request)
    46  		assert.Equal(t, false, result)
    47  	})
    48  }
    49  
    50  func TestPrepareSSZRequestForProxying(t *testing.T) {
    51  	middleware := &gateway.ApiProxyMiddleware{
    52  		GatewayAddress: "http://gateway.example",
    53  	}
    54  	endpoint := gateway.Endpoint{
    55  		Path: "http://foo.example",
    56  	}
    57  	var body bytes.Buffer
    58  	request := httptest.NewRequest("GET", "http://foo.example", &body)
    59  
    60  	errJson := prepareSSZRequestForProxying(middleware, endpoint, request, "/ssz")
    61  	require.Equal(t, true, errJson == nil)
    62  	assert.Equal(t, "/ssz", request.URL.Path)
    63  }
    64  
    65  func TestSerializeMiddlewareResponseIntoSSZ(t *testing.T) {
    66  	t.Run("ok", func(t *testing.T) {
    67  		ssz, errJson := serializeMiddlewareResponseIntoSSZ("Zm9v")
    68  		require.Equal(t, true, errJson == nil)
    69  		assert.DeepEqual(t, []byte("foo"), ssz)
    70  	})
    71  
    72  	t.Run("invalid_data", func(t *testing.T) {
    73  		_, errJson := serializeMiddlewareResponseIntoSSZ("invalid")
    74  		require.Equal(t, false, errJson == nil)
    75  		assert.Equal(t, true, strings.Contains(errJson.Msg(), "could not decode response body into base64"))
    76  		assert.Equal(t, http.StatusInternalServerError, errJson.StatusCode())
    77  	})
    78  }
    79  
    80  func TestWriteSSZResponseHeaderAndBody(t *testing.T) {
    81  	t.Run("ok", func(t *testing.T) {
    82  		response := &http.Response{
    83  			Header: http.Header{
    84  				"Foo": []string{"foo"},
    85  				"Grpc-Metadata-" + grpcutils.HttpCodeMetadataKey: []string{"204"},
    86  			},
    87  		}
    88  		responseSsz := []byte("ssz")
    89  		writer := httptest.NewRecorder()
    90  		writer.Body = &bytes.Buffer{}
    91  
    92  		errJson := writeSSZResponseHeaderAndBody(response, writer, responseSsz, "test.ssz")
    93  		require.Equal(t, true, errJson == nil)
    94  		v, ok := writer.Header()["Foo"]
    95  		require.Equal(t, true, ok, "header not found")
    96  		require.Equal(t, 1, len(v), "wrong number of header values")
    97  		assert.Equal(t, "foo", v[0])
    98  		v, ok = writer.Header()["Content-Length"]
    99  		require.Equal(t, true, ok, "header not found")
   100  		require.Equal(t, 1, len(v), "wrong number of header values")
   101  		assert.Equal(t, "3", v[0])
   102  		v, ok = writer.Header()["Content-Type"]
   103  		require.Equal(t, true, ok, "header not found")
   104  		require.Equal(t, 1, len(v), "wrong number of header values")
   105  		assert.Equal(t, "application/octet-stream", v[0])
   106  		v, ok = writer.Header()["Content-Disposition"]
   107  		require.Equal(t, true, ok, "header not found")
   108  		require.Equal(t, 1, len(v), "wrong number of header values")
   109  		assert.Equal(t, "attachment; filename=test.ssz", v[0])
   110  		assert.Equal(t, 204, writer.Code)
   111  	})
   112  
   113  	t.Run("no_grpc_status_code_header", func(t *testing.T) {
   114  		response := &http.Response{
   115  			Header:     http.Header{},
   116  			StatusCode: 204,
   117  		}
   118  		responseSsz := []byte("ssz")
   119  		writer := httptest.NewRecorder()
   120  		writer.Body = &bytes.Buffer{}
   121  
   122  		errJson := writeSSZResponseHeaderAndBody(response, writer, responseSsz, "test.ssz")
   123  		require.Equal(t, true, errJson == nil)
   124  		assert.Equal(t, 204, writer.Code)
   125  	})
   126  
   127  	t.Run("invalid_status_code", func(t *testing.T) {
   128  		response := &http.Response{
   129  			Header: http.Header{
   130  				"Foo": []string{"foo"},
   131  				"Grpc-Metadata-" + grpcutils.HttpCodeMetadataKey: []string{"invalid"},
   132  			},
   133  		}
   134  		responseSsz := []byte("ssz")
   135  		writer := httptest.NewRecorder()
   136  		writer.Body = &bytes.Buffer{}
   137  
   138  		errJson := writeSSZResponseHeaderAndBody(response, writer, responseSsz, "test.ssz")
   139  		require.Equal(t, false, errJson == nil)
   140  		assert.Equal(t, true, strings.Contains(errJson.Msg(), "could not parse status code"))
   141  		assert.Equal(t, http.StatusInternalServerError, errJson.StatusCode())
   142  	})
   143  }
   144  
   145  func TestReceiveEvents(t *testing.T) {
   146  	ctx, cancel := context.WithCancel(context.Background())
   147  	ch := make(chan *sse.Event)
   148  	w := httptest.NewRecorder()
   149  	w.Body = &bytes.Buffer{}
   150  	req := httptest.NewRequest("GET", "http://foo.example", &bytes.Buffer{})
   151  	req = req.WithContext(ctx)
   152  
   153  	go func() {
   154  		base64Val := "Zm9v"
   155  		data := &eventFinalizedCheckpointJson{
   156  			Block: base64Val,
   157  			State: base64Val,
   158  			Epoch: "1",
   159  		}
   160  		bData, err := json.Marshal(data)
   161  		require.NoError(t, err)
   162  		msg := &sse.Event{
   163  			Data:  bData,
   164  			Event: []byte(events.FinalizedCheckpointTopic),
   165  		}
   166  		ch <- msg
   167  		time.Sleep(time.Second)
   168  		cancel()
   169  	}()
   170  
   171  	errJson := receiveEvents(ch, w, req)
   172  	assert.Equal(t, true, errJson == nil)
   173  }
   174  
   175  func TestReceiveEvents_EventNotSupported(t *testing.T) {
   176  	ch := make(chan *sse.Event)
   177  	w := httptest.NewRecorder()
   178  	w.Body = &bytes.Buffer{}
   179  	req := httptest.NewRequest("GET", "http://foo.example", &bytes.Buffer{})
   180  
   181  	go func() {
   182  		msg := &sse.Event{
   183  			Data:  []byte("foo"),
   184  			Event: []byte("not_supported"),
   185  		}
   186  		ch <- msg
   187  	}()
   188  
   189  	errJson := receiveEvents(ch, w, req)
   190  	require.NotNil(t, errJson)
   191  	assert.Equal(t, "Event type 'not_supported' not supported", errJson.Msg())
   192  }
   193  
   194  func TestWriteEvent(t *testing.T) {
   195  	base64Val := "Zm9v"
   196  	data := &eventFinalizedCheckpointJson{
   197  		Block: base64Val,
   198  		State: base64Val,
   199  		Epoch: "1",
   200  	}
   201  	bData, err := json.Marshal(data)
   202  	require.NoError(t, err)
   203  	msg := &sse.Event{
   204  		Data:  bData,
   205  		Event: []byte("test_event"),
   206  	}
   207  	w := httptest.NewRecorder()
   208  	w.Body = &bytes.Buffer{}
   209  
   210  	errJson := writeEvent(msg, w, &eventFinalizedCheckpointJson{})
   211  	require.Equal(t, true, errJson == nil)
   212  	written := w.Body.String()
   213  	assert.Equal(t, "event: test_event\ndata: {\"block\":\"0x666f6f\",\"state\":\"0x666f6f\",\"epoch\":\"1\"}\n\n", written)
   214  }