github.com/vipernet-xyz/tm@v0.34.24/rpc/jsonrpc/jsonrpc_test.go (about)

     1  package jsonrpc
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	crand "crypto/rand"
     7  	"encoding/json"
     8  	"fmt"
     9  	"net/http"
    10  	"net/url"
    11  	"os"
    12  	"os/exec"
    13  	"strings"
    14  	"testing"
    15  	"time"
    16  
    17  	"github.com/go-kit/log/term"
    18  	"github.com/stretchr/testify/assert"
    19  	"github.com/stretchr/testify/require"
    20  
    21  	tmbytes "github.com/vipernet-xyz/tm/libs/bytes"
    22  	"github.com/vipernet-xyz/tm/libs/log"
    23  	tmrand "github.com/vipernet-xyz/tm/libs/rand"
    24  
    25  	client "github.com/vipernet-xyz/tm/rpc/jsonrpc/client"
    26  	server "github.com/vipernet-xyz/tm/rpc/jsonrpc/server"
    27  	types "github.com/vipernet-xyz/tm/rpc/jsonrpc/types"
    28  )
    29  
    30  // Client and Server should work over tcp or unix sockets
    31  const (
    32  	tcpAddr = "tcp://127.0.0.1:47768"
    33  
    34  	unixSocket = "/tmp/rpc_test.sock"
    35  	unixAddr   = "unix://" + unixSocket
    36  
    37  	websocketEndpoint = "/websocket/endpoint"
    38  
    39  	testVal = "acbd"
    40  )
    41  
    42  var ctx = context.Background()
    43  
    44  type ResultEcho struct {
    45  	Value string `json:"value"`
    46  }
    47  
    48  type ResultEchoInt struct {
    49  	Value int `json:"value"`
    50  }
    51  
    52  type ResultEchoBytes struct {
    53  	Value []byte `json:"value"`
    54  }
    55  
    56  type ResultEchoDataBytes struct {
    57  	Value tmbytes.HexBytes `json:"value"`
    58  }
    59  
    60  type ResultEchoWithDefault struct {
    61  	Value int `json:"value"`
    62  }
    63  
    64  // Define some routes
    65  var Routes = map[string]*server.RPCFunc{
    66  	"echo":            server.NewRPCFunc(EchoResult, "arg"),
    67  	"echo_ws":         server.NewWSRPCFunc(EchoWSResult, "arg"),
    68  	"echo_bytes":      server.NewRPCFunc(EchoBytesResult, "arg"),
    69  	"echo_data_bytes": server.NewRPCFunc(EchoDataBytesResult, "arg"),
    70  	"echo_int":        server.NewRPCFunc(EchoIntResult, "arg"),
    71  	"echo_default":    server.NewRPCFunc(EchoWithDefault, "arg", server.Cacheable("arg")),
    72  }
    73  
    74  func EchoResult(ctx *types.Context, v string) (*ResultEcho, error) {
    75  	return &ResultEcho{v}, nil
    76  }
    77  
    78  func EchoWSResult(ctx *types.Context, v string) (*ResultEcho, error) {
    79  	return &ResultEcho{v}, nil
    80  }
    81  
    82  func EchoIntResult(ctx *types.Context, v int) (*ResultEchoInt, error) {
    83  	return &ResultEchoInt{v}, nil
    84  }
    85  
    86  func EchoBytesResult(ctx *types.Context, v []byte) (*ResultEchoBytes, error) {
    87  	return &ResultEchoBytes{v}, nil
    88  }
    89  
    90  func EchoDataBytesResult(ctx *types.Context, v tmbytes.HexBytes) (*ResultEchoDataBytes, error) {
    91  	return &ResultEchoDataBytes{v}, nil
    92  }
    93  
    94  func EchoWithDefault(ctx *types.Context, v *int) (*ResultEchoWithDefault, error) {
    95  	val := -1
    96  	if v != nil {
    97  		val = *v
    98  	}
    99  	return &ResultEchoWithDefault{val}, nil
   100  }
   101  
   102  func TestMain(m *testing.M) {
   103  	setup()
   104  	code := m.Run()
   105  	os.Exit(code)
   106  }
   107  
   108  var colorFn = func(keyvals ...interface{}) term.FgBgColor {
   109  	for i := 0; i < len(keyvals)-1; i += 2 {
   110  		if keyvals[i] == "socket" {
   111  			if keyvals[i+1] == "tcp" {
   112  				return term.FgBgColor{Fg: term.DarkBlue}
   113  			} else if keyvals[i+1] == "unix" {
   114  				return term.FgBgColor{Fg: term.DarkCyan}
   115  			}
   116  		}
   117  	}
   118  	return term.FgBgColor{}
   119  }
   120  
   121  // launch unix and tcp servers
   122  func setup() {
   123  	logger := log.NewTMLoggerWithColorFn(log.NewSyncWriter(os.Stdout), colorFn)
   124  
   125  	cmd := exec.Command("rm", "-f", unixSocket)
   126  	err := cmd.Start()
   127  	if err != nil {
   128  		panic(err)
   129  	}
   130  	if err = cmd.Wait(); err != nil {
   131  		panic(err)
   132  	}
   133  
   134  	tcpLogger := logger.With("socket", "tcp")
   135  	mux := http.NewServeMux()
   136  	server.RegisterRPCFuncs(mux, Routes, tcpLogger)
   137  	wm := server.NewWebsocketManager(Routes, server.ReadWait(5*time.Second), server.PingPeriod(1*time.Second))
   138  	wm.SetLogger(tcpLogger)
   139  	mux.HandleFunc(websocketEndpoint, wm.WebsocketHandler)
   140  	config := server.DefaultConfig()
   141  	listener1, err := server.Listen(tcpAddr, config)
   142  	if err != nil {
   143  		panic(err)
   144  	}
   145  	go func() {
   146  		if err := server.Serve(listener1, mux, tcpLogger, config); err != nil {
   147  			panic(err)
   148  		}
   149  	}()
   150  
   151  	unixLogger := logger.With("socket", "unix")
   152  	mux2 := http.NewServeMux()
   153  	server.RegisterRPCFuncs(mux2, Routes, unixLogger)
   154  	wm = server.NewWebsocketManager(Routes)
   155  	wm.SetLogger(unixLogger)
   156  	mux2.HandleFunc(websocketEndpoint, wm.WebsocketHandler)
   157  	listener2, err := server.Listen(unixAddr, config)
   158  	if err != nil {
   159  		panic(err)
   160  	}
   161  	go func() {
   162  		if err := server.Serve(listener2, mux2, unixLogger, config); err != nil {
   163  			panic(err)
   164  		}
   165  	}()
   166  
   167  	// wait for servers to start
   168  	time.Sleep(time.Second * 2)
   169  }
   170  
   171  func echoViaHTTP(cl client.Caller, val string) (string, error) {
   172  	params := map[string]interface{}{
   173  		"arg": val,
   174  	}
   175  	result := new(ResultEcho)
   176  	if _, err := cl.Call(ctx, "echo", params, result); err != nil {
   177  		return "", err
   178  	}
   179  	return result.Value, nil
   180  }
   181  
   182  func echoIntViaHTTP(cl client.Caller, val int) (int, error) {
   183  	params := map[string]interface{}{
   184  		"arg": val,
   185  	}
   186  	result := new(ResultEchoInt)
   187  	if _, err := cl.Call(ctx, "echo_int", params, result); err != nil {
   188  		return 0, err
   189  	}
   190  	return result.Value, nil
   191  }
   192  
   193  func echoBytesViaHTTP(cl client.Caller, bytes []byte) ([]byte, error) {
   194  	params := map[string]interface{}{
   195  		"arg": bytes,
   196  	}
   197  	result := new(ResultEchoBytes)
   198  	if _, err := cl.Call(ctx, "echo_bytes", params, result); err != nil {
   199  		return []byte{}, err
   200  	}
   201  	return result.Value, nil
   202  }
   203  
   204  func echoDataBytesViaHTTP(cl client.Caller, bytes tmbytes.HexBytes) (tmbytes.HexBytes, error) {
   205  	params := map[string]interface{}{
   206  		"arg": bytes,
   207  	}
   208  	result := new(ResultEchoDataBytes)
   209  	if _, err := cl.Call(ctx, "echo_data_bytes", params, result); err != nil {
   210  		return []byte{}, err
   211  	}
   212  	return result.Value, nil
   213  }
   214  
   215  func echoWithDefaultViaHTTP(cl client.Caller, v *int) (int, error) {
   216  	params := map[string]interface{}{}
   217  	if v != nil {
   218  		params["arg"] = *v
   219  	}
   220  	result := new(ResultEchoWithDefault)
   221  	if _, err := cl.Call(ctx, "echo_default", params, result); err != nil {
   222  		return 0, err
   223  	}
   224  	return result.Value, nil
   225  }
   226  
   227  func testWithHTTPClient(t *testing.T, cl client.HTTPClient) {
   228  	val := testVal
   229  	got, err := echoViaHTTP(cl, val)
   230  	require.NoError(t, err)
   231  	assert.Equal(t, got, val)
   232  
   233  	val2 := randBytes(t)
   234  	got2, err := echoBytesViaHTTP(cl, val2)
   235  	require.NoError(t, err)
   236  	assert.Equal(t, got2, val2)
   237  
   238  	val3 := tmbytes.HexBytes(randBytes(t))
   239  	got3, err := echoDataBytesViaHTTP(cl, val3)
   240  	require.NoError(t, err)
   241  	assert.Equal(t, got3, val3)
   242  
   243  	val4 := tmrand.Intn(10000)
   244  	got4, err := echoIntViaHTTP(cl, val4)
   245  	require.NoError(t, err)
   246  	assert.Equal(t, got4, val4)
   247  
   248  	got5, err := echoWithDefaultViaHTTP(cl, nil)
   249  	require.NoError(t, err)
   250  	assert.Equal(t, got5, -1)
   251  
   252  	val6 := tmrand.Intn(10000)
   253  	got6, err := echoWithDefaultViaHTTP(cl, &val6)
   254  	require.NoError(t, err)
   255  	assert.Equal(t, got6, val6)
   256  }
   257  
   258  func echoViaWS(cl *client.WSClient, val string) (string, error) {
   259  	params := map[string]interface{}{
   260  		"arg": val,
   261  	}
   262  	err := cl.Call(context.Background(), "echo", params)
   263  	if err != nil {
   264  		return "", err
   265  	}
   266  
   267  	msg := <-cl.ResponsesCh
   268  	if msg.Error != nil {
   269  		return "", err
   270  	}
   271  	result := new(ResultEcho)
   272  	err = json.Unmarshal(msg.Result, result)
   273  	if err != nil {
   274  		return "", nil
   275  	}
   276  	return result.Value, nil
   277  }
   278  
   279  func echoBytesViaWS(cl *client.WSClient, bytes []byte) ([]byte, error) {
   280  	params := map[string]interface{}{
   281  		"arg": bytes,
   282  	}
   283  	err := cl.Call(context.Background(), "echo_bytes", params)
   284  	if err != nil {
   285  		return []byte{}, err
   286  	}
   287  
   288  	msg := <-cl.ResponsesCh
   289  	if msg.Error != nil {
   290  		return []byte{}, msg.Error
   291  	}
   292  	result := new(ResultEchoBytes)
   293  	err = json.Unmarshal(msg.Result, result)
   294  	if err != nil {
   295  		return []byte{}, nil
   296  	}
   297  	return result.Value, nil
   298  }
   299  
   300  func testWithWSClient(t *testing.T, cl *client.WSClient) {
   301  	val := testVal
   302  	got, err := echoViaWS(cl, val)
   303  	require.Nil(t, err)
   304  	assert.Equal(t, got, val)
   305  
   306  	val2 := randBytes(t)
   307  	got2, err := echoBytesViaWS(cl, val2)
   308  	require.Nil(t, err)
   309  	assert.Equal(t, got2, val2)
   310  }
   311  
   312  //-------------
   313  
   314  func TestServersAndClientsBasic(t *testing.T) {
   315  	serverAddrs := [...]string{tcpAddr, unixAddr}
   316  	for _, addr := range serverAddrs {
   317  		cl1, err := client.NewURI(addr)
   318  		require.Nil(t, err)
   319  		fmt.Printf("=== testing server on %s using URI client", addr)
   320  		testWithHTTPClient(t, cl1)
   321  
   322  		cl2, err := client.New(addr)
   323  		require.Nil(t, err)
   324  		fmt.Printf("=== testing server on %s using JSONRPC client", addr)
   325  		testWithHTTPClient(t, cl2)
   326  
   327  		cl3, err := client.NewWS(addr, websocketEndpoint)
   328  		require.Nil(t, err)
   329  		cl3.SetLogger(log.TestingLogger())
   330  		err = cl3.Start()
   331  		require.Nil(t, err)
   332  		fmt.Printf("=== testing server on %s using WS client", addr)
   333  		testWithWSClient(t, cl3)
   334  		err = cl3.Stop()
   335  		require.NoError(t, err)
   336  	}
   337  }
   338  
   339  func TestHexStringArg(t *testing.T) {
   340  	cl, err := client.NewURI(tcpAddr)
   341  	require.Nil(t, err)
   342  	// should NOT be handled as hex
   343  	val := "0xabc"
   344  	got, err := echoViaHTTP(cl, val)
   345  	require.Nil(t, err)
   346  	assert.Equal(t, got, val)
   347  }
   348  
   349  func TestQuotedStringArg(t *testing.T) {
   350  	cl, err := client.NewURI(tcpAddr)
   351  	require.Nil(t, err)
   352  	// should NOT be unquoted
   353  	val := "\"abc\""
   354  	got, err := echoViaHTTP(cl, val)
   355  	require.Nil(t, err)
   356  	assert.Equal(t, got, val)
   357  }
   358  
   359  func TestWSNewWSRPCFunc(t *testing.T) {
   360  	cl, err := client.NewWS(tcpAddr, websocketEndpoint)
   361  	require.Nil(t, err)
   362  	cl.SetLogger(log.TestingLogger())
   363  	err = cl.Start()
   364  	require.Nil(t, err)
   365  	t.Cleanup(func() {
   366  		if err := cl.Stop(); err != nil {
   367  			t.Error(err)
   368  		}
   369  	})
   370  
   371  	val := testVal
   372  	params := map[string]interface{}{
   373  		"arg": val,
   374  	}
   375  	err = cl.Call(context.Background(), "echo_ws", params)
   376  	require.Nil(t, err)
   377  
   378  	msg := <-cl.ResponsesCh
   379  	if msg.Error != nil {
   380  		t.Fatal(err)
   381  	}
   382  	result := new(ResultEcho)
   383  	err = json.Unmarshal(msg.Result, result)
   384  	require.Nil(t, err)
   385  	got := result.Value
   386  	assert.Equal(t, got, val)
   387  }
   388  
   389  func TestWSHandlesArrayParams(t *testing.T) {
   390  	cl, err := client.NewWS(tcpAddr, websocketEndpoint)
   391  	require.Nil(t, err)
   392  	cl.SetLogger(log.TestingLogger())
   393  	err = cl.Start()
   394  	require.Nil(t, err)
   395  	t.Cleanup(func() {
   396  		if err := cl.Stop(); err != nil {
   397  			t.Error(err)
   398  		}
   399  	})
   400  
   401  	val := testVal
   402  	params := []interface{}{val}
   403  	err = cl.CallWithArrayParams(context.Background(), "echo_ws", params)
   404  	require.Nil(t, err)
   405  
   406  	msg := <-cl.ResponsesCh
   407  	if msg.Error != nil {
   408  		t.Fatalf("%+v", err)
   409  	}
   410  	result := new(ResultEcho)
   411  	err = json.Unmarshal(msg.Result, result)
   412  	require.Nil(t, err)
   413  	got := result.Value
   414  	assert.Equal(t, got, val)
   415  }
   416  
   417  // TestWSClientPingPong checks that a client & server exchange pings
   418  // & pongs so connection stays alive.
   419  func TestWSClientPingPong(t *testing.T) {
   420  	cl, err := client.NewWS(tcpAddr, websocketEndpoint)
   421  	require.Nil(t, err)
   422  	cl.SetLogger(log.TestingLogger())
   423  	err = cl.Start()
   424  	require.Nil(t, err)
   425  	t.Cleanup(func() {
   426  		if err := cl.Stop(); err != nil {
   427  			t.Error(err)
   428  		}
   429  	})
   430  
   431  	time.Sleep(6 * time.Second)
   432  }
   433  
   434  func TestJSONRPCCaching(t *testing.T) {
   435  	httpAddr := strings.Replace(tcpAddr, "tcp://", "http://", 1)
   436  	cl, err := client.DefaultHTTPClient(httpAddr)
   437  	require.NoError(t, err)
   438  
   439  	// Not supplying the arg should result in not caching
   440  	params := make(map[string]interface{})
   441  	req, err := types.MapToRequest(types.JSONRPCIntID(1000), "echo_default", params)
   442  	require.NoError(t, err)
   443  
   444  	res1, err := rawJSONRPCRequest(t, cl, httpAddr, req)
   445  	defer func() { _ = res1.Body.Close() }()
   446  	require.NoError(t, err)
   447  	assert.Equal(t, "", res1.Header.Get("Cache-control"))
   448  
   449  	// Supplying the arg should result in caching
   450  	params["arg"] = tmrand.Intn(10000)
   451  	req, err = types.MapToRequest(types.JSONRPCIntID(1001), "echo_default", params)
   452  	require.NoError(t, err)
   453  
   454  	res2, err := rawJSONRPCRequest(t, cl, httpAddr, req)
   455  	defer func() { _ = res2.Body.Close() }()
   456  	require.NoError(t, err)
   457  	assert.Equal(t, "public, max-age=86400", res2.Header.Get("Cache-control"))
   458  }
   459  
   460  func rawJSONRPCRequest(t *testing.T, cl *http.Client, url string, req interface{}) (*http.Response, error) {
   461  	reqBytes, err := json.Marshal(req)
   462  	require.NoError(t, err)
   463  
   464  	reqBuf := bytes.NewBuffer(reqBytes)
   465  	httpReq, err := http.NewRequest(http.MethodPost, url, reqBuf)
   466  	require.NoError(t, err)
   467  
   468  	httpReq.Header.Set("Content-type", "application/json")
   469  
   470  	return cl.Do(httpReq)
   471  }
   472  
   473  func TestURICaching(t *testing.T) {
   474  	httpAddr := strings.Replace(tcpAddr, "tcp://", "http://", 1)
   475  	cl, err := client.DefaultHTTPClient(httpAddr)
   476  	require.NoError(t, err)
   477  
   478  	// Not supplying the arg should result in not caching
   479  	args := url.Values{}
   480  	res1, err := rawURIRequest(t, cl, httpAddr+"/echo_default", args)
   481  	defer func() { _ = res1.Body.Close() }()
   482  	require.NoError(t, err)
   483  	assert.Equal(t, "", res1.Header.Get("Cache-control"))
   484  
   485  	// Supplying the arg should result in caching
   486  	args.Set("arg", fmt.Sprintf("%d", tmrand.Intn(10000)))
   487  	res2, err := rawURIRequest(t, cl, httpAddr+"/echo_default", args)
   488  	defer func() { _ = res2.Body.Close() }()
   489  	require.NoError(t, err)
   490  	assert.Equal(t, "public, max-age=86400", res2.Header.Get("Cache-control"))
   491  }
   492  
   493  func rawURIRequest(t *testing.T, cl *http.Client, url string, args url.Values) (*http.Response, error) {
   494  	req, err := http.NewRequest(http.MethodPost, url, strings.NewReader(args.Encode()))
   495  	require.NoError(t, err)
   496  
   497  	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
   498  
   499  	return cl.Do(req)
   500  }
   501  
   502  func randBytes(t *testing.T) []byte {
   503  	n := tmrand.Intn(10) + 2
   504  	buf := make([]byte, n)
   505  	_, err := crand.Read(buf)
   506  	require.Nil(t, err)
   507  	return bytes.ReplaceAll(buf, []byte("="), []byte{100})
   508  }