github.com/prysmaticlabs/prysm@v1.4.4/shared/gateway/gateway_test.go (about)

     1  package gateway
     2  
import (
	"context"
	"flag"
	"fmt"
	"net/http"
	"strings"
	"testing"
	"time"

	"github.com/prysmaticlabs/prysm/cmd/beacon-chain/flags"
	"github.com/prysmaticlabs/prysm/shared/testutil/assert"
	"github.com/prysmaticlabs/prysm/shared/testutil/require"
	logTest "github.com/sirupsen/logrus/hooks/test"
	"github.com/urfave/cli/v2"
)
    16  
    17  type mockEndpointFactory struct {
    18  }
    19  
    20  func (*mockEndpointFactory) Paths() []string {
    21  	return []string{}
    22  }
    23  
    24  func (*mockEndpointFactory) Create(_ string) (*Endpoint, error) {
    25  	return nil, nil
    26  }
    27  
    28  func (*mockEndpointFactory) IsNil() bool {
    29  	return false
    30  }
    31  
    32  func TestGateway_Customized(t *testing.T) {
    33  	mux := http.NewServeMux()
    34  	cert := "cert"
    35  	origins := []string{"origin"}
    36  	size := uint64(100)
    37  	middlewareAddr := "middleware"
    38  	endpointFactory := &mockEndpointFactory{}
    39  
    40  	g := New(
    41  		context.Background(),
    42  		[]PbMux{},
    43  		func(handler http.Handler, writer http.ResponseWriter, request *http.Request) {
    44  
    45  		},
    46  		"",
    47  		"",
    48  	).WithMux(mux).
    49  		WithRemoteCert(cert).
    50  		WithAllowedOrigins(origins).
    51  		WithMaxCallRecvMsgSize(size).
    52  		WithApiMiddleware(middlewareAddr, endpointFactory)
    53  
    54  	assert.Equal(t, mux, g.mux)
    55  	assert.Equal(t, cert, g.remoteCert)
    56  	require.Equal(t, 1, len(g.allowedOrigins))
    57  	assert.Equal(t, origins[0], g.allowedOrigins[0])
    58  	assert.Equal(t, size, g.maxCallRecvMsgSize)
    59  	assert.Equal(t, middlewareAddr, g.apiMiddlewareAddr)
    60  	assert.Equal(t, endpointFactory, g.apiMiddlewareEndpointFactory)
    61  }
    62  
    63  func TestGateway_StartStop(t *testing.T) {
    64  	hook := logTest.NewGlobal()
    65  
    66  	app := cli.App{}
    67  	set := flag.NewFlagSet("test", 0)
    68  	ctx := cli.NewContext(&app, set, nil)
    69  
    70  	gatewayPort := ctx.Int(flags.GRPCGatewayPort.Name)
    71  	gatewayHost := ctx.String(flags.GRPCGatewayHost.Name)
    72  	rpcHost := ctx.String(flags.RPCHost.Name)
    73  	selfAddress := fmt.Sprintf("%s:%d", rpcHost, ctx.Int(flags.RPCPort.Name))
    74  	gatewayAddress := fmt.Sprintf("%s:%d", gatewayHost, gatewayPort)
    75  
    76  	g := New(
    77  		ctx.Context,
    78  		[]PbMux{},
    79  		func(handler http.Handler, writer http.ResponseWriter, request *http.Request) {
    80  
    81  		},
    82  		selfAddress,
    83  		gatewayAddress,
    84  	)
    85  
    86  	g.Start()
    87  	go func() {
    88  		require.LogsContain(t, hook, "Starting gRPC gateway")
    89  		require.LogsDoNotContain(t, hook, "Starting API middleware")
    90  	}()
    91  
    92  	err := g.Stop()
    93  	require.NoError(t, err)
    94  }