github.com/klaytn/klaytn@v1.10.2/networks/rpc/websocket_test.go (about)

     1  // Modifications Copyright 2020 The klaytn Authors
     2  // Copyright 2018 The go-ethereum Authors
     3  // This file is part of the go-ethereum library.
     4  //
     5  // The go-ethereum library is free software: you can redistribute it and/or modify
     6  // it under the terms of the GNU Lesser General Public License as published by
     7  // the Free Software Foundation, either version 3 of the License, or
     8  // (at your option) any later version.
     9  //
    10  // The go-ethereum library is distributed in the hope that it will be useful,
    11  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    12  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    13  // GNU Lesser General Public License for more details.
    14  //
    15  // You should have received a copy of the GNU Lesser General Public License
    16  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    17  //
    18  // This file is derived from rpc/websocket_test.go (2020/04/03).
    19  // Modified and improved for the klaytn development.
    20  
    21  package rpc
    22  
    23  import (
    24  	"context"
    25  	"encoding/base64"
    26  	"net"
    27  	"net/http"
    28  	"net/http/httptest"
    29  	"reflect"
    30  	"strings"
    31  	"testing"
    32  	"time"
    33  
    34  	"github.com/gorilla/websocket"
    35  	"github.com/klaytn/klaytn/common"
    36  	"github.com/stretchr/testify/assert"
    37  )
    38  
    39  type echoArgs struct {
    40  	S string
    41  }
    42  
    43  type echoResult struct {
    44  	String string
    45  	Int    int
    46  	Args   *echoArgs
    47  }
    48  
    49  func TestWebsocketLargeCall(t *testing.T) {
    50  	t.Parallel()
    51  
    52  	// create server
    53  	var (
    54  		srv     = newTestServer("service", new(Service))
    55  		httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"*"}))
    56  		wsAddr  = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:")
    57  	)
    58  	defer srv.Stop()
    59  	defer httpsrv.Close()
    60  	time.Sleep(100 * time.Millisecond)
    61  
    62  	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
    63  	defer cancel()
    64  	client, err := DialWebsocket(ctx, wsAddr, "")
    65  	if err != nil {
    66  		t.Fatalf("can't dial: %v", err)
    67  	}
    68  	defer client.Close()
    69  
    70  	// set configurations before testing
    71  	var result echoResult
    72  	method := "service_echo"
    73  
    74  	// This call sends slightly less than the limit and should work.
    75  	arg := strings.Repeat("x", common.MaxRequestContentLength-200)
    76  	assert.NoError(t, client.Call(&result, method, arg, 1), "valid call didn't work")
    77  	assert.Equal(t, arg, result.String, "wrong string echoed")
    78  
    79  	// This call sends slightly larger than the allowed size and shouldn't work.
    80  	arg = strings.Repeat("x", common.MaxRequestContentLength)
    81  	assert.Error(t, client.Call(&result, method, arg, 1), "no error for too large call")
    82  }
    83  
    84  func newTestListener() net.Listener {
    85  	ln, err := net.Listen("tcp", "localhost:0")
    86  	if err != nil {
    87  		panic(err)
    88  	}
    89  	return ln
    90  }
    91  
    92  /*
    93  func TestWSServer_MaxConnections(t *testing.T) {
    94  	// create server
    95  	var (
    96  		srv = newTestServer("service", new(Service))
    97  		ln  = newTestListener()
    98  	)
    99  	defer srv.Stop()
   100  	defer ln.Close()
   101  
   102  	go NewWSServer([]string{"*"}, srv).Serve(ln)
   103  	time.Sleep(100 * time.Millisecond)
   104  
   105  	// set max websocket connections
   106  	MaxWebsocketConnections = 3
   107  	testWebsocketMaxConnections(t, "ws://"+ln.Addr().String(), int(MaxWebsocketConnections))
   108  }
   109  */
   110  
   111  func TestFastWSServer_MaxConnections(t *testing.T) {
   112  	// create server
   113  	var (
   114  		srv = newTestServer("service", new(Service))
   115  		ln  = newTestListener()
   116  	)
   117  	defer srv.Stop()
   118  	defer ln.Close()
   119  
   120  	go NewFastWSServer([]string{"*"}, srv).Serve(ln)
   121  	time.Sleep(100 * time.Millisecond)
   122  
   123  	// set max websocket connections
   124  	MaxWebsocketConnections = 3
   125  	testWebsocketMaxConnections(t, "ws://"+ln.Addr().String(), int(MaxWebsocketConnections))
   126  }
   127  
   128  func testWebsocketMaxConnections(t *testing.T, addr string, maxConnections int) {
   129  	var closers []*Client
   130  
   131  	for i := 0; i <= maxConnections; i++ {
   132  		client, err := DialWebsocket(context.Background(), addr, "")
   133  		if err != nil {
   134  			t.Fatal(err)
   135  		}
   136  		closers = append(closers, client)
   137  
   138  		var result echoResult
   139  		method := "service_echo"
   140  		arg := strings.Repeat("x", i)
   141  		err = client.Call(&result, method, arg, 1)
   142  		if i < int(MaxWebsocketConnections) {
   143  			assert.NoError(t, err)
   144  			assert.Equal(t, arg, result.String, "wrong string echoed")
   145  		} else {
   146  			assert.Error(t, err)
   147  			// assert.Equal(t, "EOF", err.Error())
   148  		}
   149  	}
   150  
   151  	for _, client := range closers {
   152  		client.Close()
   153  	}
   154  }
   155  
   156  func TestWebsocketClientHeaders(t *testing.T) {
   157  	t.Parallel()
   158  
   159  	endpoint, header, err := wsClientHeaders("wss://testuser:test-PASS_01@example.com:1234", "https://example.com")
   160  	if err != nil {
   161  		t.Fatalf("wsGetConfig failed: %s", err)
   162  	}
   163  	if endpoint != "wss://example.com:1234" {
   164  		t.Fatal("User should have been stripped from the URL")
   165  	}
   166  	if header.Get("authorization") != "Basic dGVzdHVzZXI6dGVzdC1QQVNTXzAx" {
   167  		t.Fatal("Basic auth header is incorrect")
   168  	}
   169  	if header.Get("origin") != "https://example.com" {
   170  		t.Fatal("Origin not set")
   171  	}
   172  }
   173  
   174  // This test checks that the server rejects connections from disallowed origins.
   175  func TestWebsocketOriginCheck(t *testing.T) {
   176  	t.Parallel()
   177  
   178  	var (
   179  		srv     = newTestServer("service", new(Service))
   180  		httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"http://example.com"}))
   181  		wsURL   = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:")
   182  	)
   183  	defer srv.Stop()
   184  	defer httpsrv.Close()
   185  
   186  	client, err := DialWebsocket(context.Background(), wsURL, "http://ekzample.com")
   187  	if err == nil {
   188  		client.Close()
   189  		t.Fatal("no error for wrong origin")
   190  	}
   191  	wantErr := wsHandshakeError{websocket.ErrBadHandshake, "403 Forbidden"}
   192  	if !reflect.DeepEqual(err, wantErr) {
   193  		t.Fatalf("wrong error for wrong origin: %q", err)
   194  	}
   195  }
   196  
   197  func TestClientWebsocketPing(t *testing.T) {
   198  	t.Parallel()
   199  
   200  	var (
   201  		sendPing    = make(chan struct{})
   202  		server      = wsPingTestServer(t, sendPing)
   203  		ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second)
   204  	)
   205  	defer cancel()
   206  	defer server.Shutdown(ctx)
   207  
   208  	client, err := DialContext(ctx, "ws://"+server.Addr)
   209  	if err != nil {
   210  		t.Fatalf("client dial error: %v", err)
   211  	}
   212  	resultChan := make(chan int)
   213  	sub, err := client.KlaySubscribe(ctx, resultChan, "foo")
   214  	if err != nil {
   215  		t.Fatalf("client subscribe error: %v", err)
   216  	}
   217  
   218  	// Wait for the context's deadline to be reached before proceeding.
   219  	// This is important for reproducing https://github.com/ethereum/go-ethereum/issues/19798
   220  	<-ctx.Done()
   221  	close(sendPing)
   222  
   223  	// Wait for the subscription result.
   224  	timeout := time.NewTimer(5 * time.Second)
   225  	for {
   226  		select {
   227  		case err := <-sub.Err():
   228  			t.Error("client subscription error:", err)
   229  		case result := <-resultChan:
   230  			t.Log("client got result:", result)
   231  			return
   232  		case <-timeout.C:
   233  			t.Error("didn't get any result within the test timeout")
   234  			return
   235  		}
   236  	}
   237  }
   238  
   239  // wsPingTestServer runs a WebSocket server which accepts a single subscription request.
   240  // When a value arrives on sendPing, the server sends a ping frame, waits for a matching
   241  // pong and finally delivers a single subscription result.
   242  func wsPingTestServer(t *testing.T, sendPing <-chan struct{}) *http.Server {
   243  	var srv http.Server
   244  	shutdown := make(chan struct{})
   245  	srv.RegisterOnShutdown(func() {
   246  		close(shutdown)
   247  	})
   248  	srv.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   249  		// Upgrade to WebSocket.
   250  		upgrader := websocket.Upgrader{
   251  			CheckOrigin: func(r *http.Request) bool { return true },
   252  		}
   253  		conn, err := upgrader.Upgrade(w, r, nil)
   254  		if err != nil {
   255  			t.Errorf("server WS upgrade error: %v", err)
   256  			return
   257  		}
   258  		defer conn.Close()
   259  
   260  		// Handle the connection.
   261  		wsPingTestHandler(t, conn, shutdown, sendPing)
   262  	})
   263  
   264  	// Start the server.
   265  	listener, err := net.Listen("tcp", "127.0.0.1:0")
   266  	if err != nil {
   267  		t.Fatal("can't listen:", err)
   268  	}
   269  	srv.Addr = listener.Addr().String()
   270  	go srv.Serve(listener)
   271  	return &srv
   272  }
   273  
   274  func wsPingTestHandler(t *testing.T, conn *websocket.Conn, shutdown, sendPing <-chan struct{}) {
   275  	// Canned responses for the eth_subscribe call in TestClientWebsocketPing.
   276  	const (
   277  		subResp   = `{"jsonrpc":"2.0","id":1,"result":"0x00"}`
   278  		subNotify = `{"jsonrpc":"2.0","method":"eth_subscription","params":{"subscription":"0x00","result":1}}`
   279  	)
   280  
   281  	// Handle subscribe request.
   282  	if _, _, err := conn.ReadMessage(); err != nil {
   283  		t.Errorf("server read error: %v", err)
   284  		return
   285  	}
   286  	if err := conn.WriteMessage(websocket.TextMessage, []byte(subResp)); err != nil {
   287  		t.Errorf("server write error: %v", err)
   288  		return
   289  	}
   290  
   291  	// Read from the connection to process control messages.
   292  	pongCh := make(chan string)
   293  	conn.SetPongHandler(func(d string) error {
   294  		t.Logf("server got pong: %q", d)
   295  		pongCh <- d
   296  		return nil
   297  	})
   298  	go func() {
   299  		for {
   300  			typ, msg, err := conn.ReadMessage()
   301  			if err != nil {
   302  				return
   303  			}
   304  			t.Logf("server got message (%d): %q", typ, msg)
   305  		}
   306  	}()
   307  
   308  	// Write messages.
   309  	var (
   310  		sendResponse <-chan time.Time
   311  		wantPong     string
   312  	)
   313  	for {
   314  		select {
   315  		case _, open := <-sendPing:
   316  			if !open {
   317  				sendPing = nil
   318  			}
   319  			t.Logf("server sending ping")
   320  			conn.WriteMessage(websocket.PingMessage, []byte("ping"))
   321  			wantPong = "ping"
   322  		case data := <-pongCh:
   323  			if wantPong == "" {
   324  				t.Errorf("unexpected pong")
   325  			} else if data != wantPong {
   326  				t.Errorf("got pong with wrong data %q", data)
   327  			}
   328  			wantPong = ""
   329  			sendResponse = time.NewTimer(200 * time.Millisecond).C
   330  		case <-sendResponse:
   331  			t.Logf("server sending response")
   332  			conn.WriteMessage(websocket.TextMessage, []byte(subNotify))
   333  			sendResponse = nil
   334  		case <-shutdown:
   335  			conn.Close()
   336  			return
   337  		}
   338  	}
   339  }
   340  
   341  func TestWebsocketAuthCheck(t *testing.T) {
   342  	t.Parallel()
   343  
   344  	var (
   345  		srv     = newTestServer("service", new(Service))
   346  		httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"http://example.com"}))
   347  		wsURL   = "ws://testuser:test-PASS_01@" + strings.TrimPrefix(httpsrv.URL, "http://")
   348  	)
   349  	connect := false
   350  	origHandler := httpsrv.Config.Handler
   351  	httpsrv.Config.Handler = http.HandlerFunc(
   352  		func(w http.ResponseWriter, r *http.Request) {
   353  			auth := r.Header.Get("Authorization")
   354  			expectedAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte("testuser:test-PASS_01"))
   355  			if r.Method == http.MethodGet && auth == expectedAuth {
   356  				connect = true
   357  				w.WriteHeader(http.StatusSwitchingProtocols)
   358  				return
   359  			}
   360  			if !connect {
   361  				http.Error(w, "connect with authorization not received", http.StatusMethodNotAllowed)
   362  				return
   363  			}
   364  			origHandler.ServeHTTP(w, r)
   365  		})
   366  	defer srv.Stop()
   367  	defer httpsrv.Close()
   368  
   369  	client, err := DialWebsocket(context.Background(), wsURL, "http://example.com")
   370  	if err == nil {
   371  		client.Close()
   372  		t.Fatal("no error for connect with auth header")
   373  	}
   374  	wantErr := wsHandshakeError{websocket.ErrBadHandshake, "101 Switching Protocols"}
   375  	if !reflect.DeepEqual(err, wantErr) {
   376  		t.Fatalf("wrong error for auth header: %q", err)
   377  	}
   378  }