github.com/gnolang/gno@v0.0.0-20240520182011-228e9d0192ce/tm2/pkg/bft/rpc/lib/server/handlers_test.go (about)

     1  package rpcserver_test
     2  
     3  import (
     4  	"encoding/json"
     5  	"io"
     6  	"net/http"
     7  	"net/http/httptest"
     8  	"strings"
     9  	"testing"
    10  
    11  	"github.com/gorilla/websocket"
    12  	"github.com/stretchr/testify/assert"
    13  	"github.com/stretchr/testify/require"
    14  
    15  	rs "github.com/gnolang/gno/tm2/pkg/bft/rpc/lib/server"
    16  	types "github.com/gnolang/gno/tm2/pkg/bft/rpc/lib/types"
    17  	"github.com/gnolang/gno/tm2/pkg/log"
    18  )
    19  
    20  // -----------
    21  // HTTP REST API
    22  // TODO
    23  
    24  // -----------
    25  // JSON-RPC over HTTP
    26  
    27  func testMux() *http.ServeMux {
    28  	funcMap := map[string]*rs.RPCFunc{
    29  		"c": rs.NewRPCFunc(func(ctx *types.Context, s string, i int) (string, error) { return "foo", nil }, "s,i"),
    30  	}
    31  	mux := http.NewServeMux()
    32  
    33  	rs.RegisterRPCFuncs(mux, funcMap, log.NewNoopLogger())
    34  
    35  	return mux
    36  }
    37  
    38  func statusOK(code int) bool { return code >= 200 && code <= 299 }
    39  
    40  // Ensure that nefarious/unintended inputs to `params`
    41  // do not crash our RPC handlers.
    42  // See Issue https://github.com/gnolang/gno/tm2/pkg/bft/issues/708.
    43  func TestRPCParams(t *testing.T) {
    44  	t.Parallel()
    45  
    46  	mux := testMux()
    47  	tests := []struct {
    48  		payload    string
    49  		wantErr    string
    50  		expectedId interface{}
    51  	}{
    52  		// bad
    53  		{`{"jsonrpc": "2.0", "id": "0"}`, "Method not found", types.JSONRPCStringID("0")},
    54  		{`{"jsonrpc": "2.0", "method": "y", "id": "0"}`, "Method not found", types.JSONRPCStringID("0")},
    55  		{`{"method": "c", "id": "0", "params": a}`, "invalid character", types.JSONRPCStringID("")}, // id not captured in JSON parsing failures
    56  		{`{"method": "c", "id": "0", "params": ["a"]}`, "got 1", types.JSONRPCStringID("0")},
    57  		{`{"method": "c", "id": "0", "params": ["a", "b"]}`, "invalid character", types.JSONRPCStringID("0")},
    58  		{`{"method": "c", "id": "0", "params": [1, 1]}`, "of type string", types.JSONRPCStringID("0")},
    59  
    60  		// good
    61  		{`{"jsonrpc": "2.0", "method": "c", "id": "0", "params": null}`, "", types.JSONRPCStringID("0")},
    62  		{`{"method": "c", "id": "0", "params": {}}`, "", types.JSONRPCStringID("0")},
    63  		{`{"method": "c", "id": "0", "params": ["a", "10"]}`, "", types.JSONRPCStringID("0")},
    64  	}
    65  
    66  	for i, tt := range tests {
    67  		req, _ := http.NewRequest("POST", "http://localhost/", strings.NewReader(tt.payload))
    68  		rec := httptest.NewRecorder()
    69  		mux.ServeHTTP(rec, req)
    70  		res := rec.Result()
    71  		// Always expecting back a JSONRPCResponse
    72  		assert.True(t, statusOK(res.StatusCode), "#%d: should always return 2XX", i)
    73  		blob, err := io.ReadAll(res.Body)
    74  		if err != nil {
    75  			t.Errorf("#%d: err reading body: %v", i, err)
    76  			continue
    77  		}
    78  
    79  		recv := new(types.RPCResponse)
    80  		assert.Nil(t, json.Unmarshal(blob, recv), "#%d: expecting successful parsing of an RPCResponse:\nblob: %s", i, blob)
    81  		assert.NotEqual(t, recv, new(types.RPCResponse), "#%d: not expecting a blank RPCResponse", i)
    82  		assert.Equal(t, tt.expectedId, recv.ID, "#%d: expected ID not matched in RPCResponse", i)
    83  		if tt.wantErr == "" {
    84  			assert.Nil(t, recv.Error, "#%d: not expecting an error", i)
    85  		} else {
    86  			assert.True(t, recv.Error.Code < 0, "#%d: not expecting a positive JSONRPC code", i)
    87  			// The wanted error is either in the message or the data
    88  			assert.Contains(t, recv.Error.Message+recv.Error.Data, tt.wantErr, "#%d: expected substring", i)
    89  		}
    90  	}
    91  }
    92  
    93  func TestJSONRPCID(t *testing.T) {
    94  	t.Parallel()
    95  
    96  	mux := testMux()
    97  	tests := []struct {
    98  		payload    string
    99  		wantErr    bool
   100  		expectedId interface{}
   101  	}{
   102  		// good id
   103  		{`{"jsonrpc": "2.0", "method": "c", "id": "0", "params": ["a", "10"]}`, false, types.JSONRPCStringID("0")},
   104  		{`{"jsonrpc": "2.0", "method": "c", "id": "abc", "params": ["a", "10"]}`, false, types.JSONRPCStringID("abc")},
   105  		{`{"jsonrpc": "2.0", "method": "c", "id": 0, "params": ["a", "10"]}`, false, types.JSONRPCIntID(0)},
   106  		{`{"jsonrpc": "2.0", "method": "c", "id": 1, "params": ["a", "10"]}`, false, types.JSONRPCIntID(1)},
   107  		{`{"jsonrpc": "2.0", "method": "c", "id": 1.3, "params": ["a", "10"]}`, false, types.JSONRPCIntID(1)},
   108  		{`{"jsonrpc": "2.0", "method": "c", "id": -1, "params": ["a", "10"]}`, false, types.JSONRPCIntID(-1)},
   109  		{`{"jsonrpc": "2.0", "method": "c", "id": null, "params": ["a", "10"]}`, false, nil},
   110  
   111  		// bad id
   112  		{`{"jsonrpc": "2.0", "method": "c", "id": {}, "params": ["a", "10"]}`, true, nil},
   113  		{`{"jsonrpc": "2.0", "method": "c", "id": [], "params": ["a", "10"]}`, true, nil},
   114  	}
   115  
   116  	for i, tt := range tests {
   117  		req, _ := http.NewRequest("POST", "http://localhost/", strings.NewReader(tt.payload))
   118  		rec := httptest.NewRecorder()
   119  		mux.ServeHTTP(rec, req)
   120  		res := rec.Result()
   121  		// Always expecting back a JSONRPCResponse
   122  		assert.True(t, statusOK(res.StatusCode), "#%d: should always return 2XX", i)
   123  		blob, err := io.ReadAll(res.Body)
   124  		if err != nil {
   125  			t.Errorf("#%d: err reading body: %v", i, err)
   126  			continue
   127  		}
   128  
   129  		recv := new(types.RPCResponse)
   130  		err = json.Unmarshal(blob, recv)
   131  		assert.Nil(t, err, "#%d: expecting successful parsing of an RPCResponse:\nblob: %s", i, blob)
   132  		if !tt.wantErr {
   133  			assert.NotEqual(t, recv, new(types.RPCResponse), "#%d: not expecting a blank RPCResponse", i)
   134  			assert.Equal(t, tt.expectedId, recv.ID, "#%d: expected ID not matched in RPCResponse", i)
   135  			assert.Nil(t, recv.Error, "#%d: not expecting an error", i)
   136  		} else {
   137  			assert.True(t, recv.Error.Code < 0, "#%d: not expecting a positive JSONRPC code", i)
   138  		}
   139  	}
   140  }
   141  
   142  func TestRPCNotification(t *testing.T) {
   143  	t.Parallel()
   144  
   145  	mux := testMux()
   146  	body := strings.NewReader(`{"jsonrpc": "2.0", "id": ""}`)
   147  	req, _ := http.NewRequest("POST", "http://localhost/", body)
   148  	rec := httptest.NewRecorder()
   149  	mux.ServeHTTP(rec, req)
   150  	res := rec.Result()
   151  
   152  	// Always expecting back a JSONRPCResponse
   153  	require.True(t, statusOK(res.StatusCode), "should always return 2XX")
   154  	blob, err := io.ReadAll(res.Body)
   155  	require.Nil(t, err, "reading from the body should not give back an error")
   156  	require.Equal(t, len(blob), 0, "a notification SHOULD NOT be responded to by the server")
   157  }
   158  
   159  func TestRPCNotificationInBatch(t *testing.T) {
   160  	t.Parallel()
   161  
   162  	mux := testMux()
   163  	tests := []struct {
   164  		payload     string
   165  		expectCount int
   166  	}{
   167  		{
   168  			`[
   169  				{"jsonrpc": "2.0","id": ""},
   170  				{"jsonrpc": "2.0","method":"c","id":"abc","params":["a","10"]}
   171  			 ]`,
   172  			1,
   173  		},
   174  		{
   175  			`[
   176  				{"jsonrpc": "2.0","id": ""},
   177  				{"jsonrpc": "2.0","method":"c","id":"abc","params":["a","10"]},
   178  				{"jsonrpc": "2.0","id": ""},
   179  				{"jsonrpc": "2.0","method":"c","id":"abc","params":["a","10"]}
   180  			 ]`,
   181  			2,
   182  		},
   183  	}
   184  	for i, tt := range tests {
   185  		req, _ := http.NewRequest("POST", "http://localhost/", strings.NewReader(tt.payload))
   186  		rec := httptest.NewRecorder()
   187  		mux.ServeHTTP(rec, req)
   188  		res := rec.Result()
   189  		// Always expecting back a JSONRPCResponse
   190  		assert.True(t, statusOK(res.StatusCode), "#%d: should always return 2XX", i)
   191  		blob, err := io.ReadAll(res.Body)
   192  		if err != nil {
   193  			t.Errorf("#%d: err reading body: %v", i, err)
   194  			continue
   195  		}
   196  
   197  		var responses types.RPCResponses
   198  		// try to unmarshal an array first
   199  		err = json.Unmarshal(blob, &responses)
   200  		if err != nil {
   201  			// if we were actually expecting an array, but got an error
   202  			if tt.expectCount > 1 {
   203  				t.Errorf("#%d: expected an array, couldn't unmarshal it\nblob: %s", i, blob)
   204  				continue
   205  			} else {
   206  				// we were expecting an error here, so let's unmarshal a single response
   207  				var response types.RPCResponse
   208  				err = json.Unmarshal(blob, &response)
   209  				if err != nil {
   210  					t.Errorf("#%d: expected successful parsing of an RPCResponse\nblob: %s", i, blob)
   211  					continue
   212  				}
   213  				// have a single-element result
   214  				responses = types.RPCResponses{response}
   215  			}
   216  		}
   217  		if tt.expectCount != len(responses) {
   218  			t.Errorf("#%d: expected %d response(s), but got %d\nblob: %s", i, tt.expectCount, len(responses), blob)
   219  			continue
   220  		}
   221  		for _, response := range responses {
   222  			assert.NotEqual(t, response, new(types.RPCResponse), "#%d: not expecting a blank RPCResponse", i)
   223  		}
   224  	}
   225  }
   226  
   227  func TestUnknownRPCPath(t *testing.T) {
   228  	t.Parallel()
   229  
   230  	mux := testMux()
   231  	req, _ := http.NewRequest("GET", "http://localhost/unknownrpcpath", nil)
   232  	rec := httptest.NewRecorder()
   233  	mux.ServeHTTP(rec, req)
   234  	res := rec.Result()
   235  
   236  	// Always expecting back a 404 error
   237  	require.Equal(t, http.StatusNotFound, res.StatusCode, "should always return 404")
   238  }
   239  
   240  // -----------
   241  // JSON-RPC over WEBSOCKETS
   242  
   243  func TestWebsocketManagerHandler(t *testing.T) {
   244  	t.Parallel()
   245  
   246  	s := newWSServer()
   247  	defer s.Close()
   248  
   249  	// check upgrader works
   250  	d := websocket.Dialer{}
   251  	c, dialResp, err := d.Dial("ws://"+s.Listener.Addr().String()+"/websocket", nil)
   252  	require.NoError(t, err)
   253  
   254  	if got, want := dialResp.StatusCode, http.StatusSwitchingProtocols; got != want {
   255  		t.Errorf("dialResp.StatusCode = %q, want %q", got, want)
   256  	}
   257  
   258  	// check basic functionality works
   259  	req, err := types.MapToRequest(types.JSONRPCStringID("TestWebsocketManager"), "c", map[string]interface{}{"s": "a", "i": 10})
   260  	require.NoError(t, err)
   261  	err = c.WriteJSON(req)
   262  	require.NoError(t, err)
   263  
   264  	var resp types.RPCResponse
   265  	err = c.ReadJSON(&resp)
   266  	require.NoError(t, err)
   267  	require.Nil(t, resp.Error)
   268  }
   269  
   270  func newWSServer() *httptest.Server {
   271  	funcMap := map[string]*rs.RPCFunc{
   272  		"c": rs.NewWSRPCFunc(func(ctx *types.Context, s string, i int) (string, error) { return "foo", nil }, "s,i"),
   273  	}
   274  	wm := rs.NewWebsocketManager(funcMap)
   275  	wm.SetLogger(log.NewNoopLogger())
   276  
   277  	mux := http.NewServeMux()
   278  	mux.HandleFunc("/websocket", wm.WebsocketHandler)
   279  
   280  	return httptest.NewServer(mux)
   281  }