github.com/ethereum/go-ethereum@v1.14.3/rpc/subscription_test.go (about)

     1  // Copyright 2016 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package rpc
    18  
    19  import (
    20  	"bytes"
    21  	"context"
    22  	"encoding/json"
    23  	"fmt"
    24  	"io"
    25  	"math/big"
    26  	"net"
    27  	"strings"
    28  	"testing"
    29  	"time"
    30  
    31  	"github.com/ethereum/go-ethereum/common"
    32  	"github.com/ethereum/go-ethereum/core/types"
    33  )
    34  
    35  func TestNewID(t *testing.T) {
    36  	hexchars := "0123456789ABCDEFabcdef"
    37  	for i := 0; i < 100; i++ {
    38  		id := string(NewID())
    39  		if !strings.HasPrefix(id, "0x") {
    40  			t.Fatalf("invalid ID prefix, want '0x...', got %s", id)
    41  		}
    42  
    43  		id = id[2:]
    44  		if len(id) == 0 || len(id) > 32 {
    45  			t.Fatalf("invalid ID length, want len(id) > 0 && len(id) <= 32), got %d", len(id))
    46  		}
    47  
    48  		for i := 0; i < len(id); i++ {
    49  			if strings.IndexByte(hexchars, id[i]) == -1 {
    50  				t.Fatalf("unexpected byte, want any valid hex char, got %c", id[i])
    51  			}
    52  		}
    53  	}
    54  }
    55  
    56  func TestSubscriptions(t *testing.T) {
    57  	var (
    58  		namespaces        = []string{"eth", "bzz"}
    59  		service           = &notificationTestService{}
    60  		subCount          = len(namespaces)
    61  		notificationCount = 3
    62  
    63  		server                 = NewServer()
    64  		clientConn, serverConn = net.Pipe()
    65  		out                    = json.NewEncoder(clientConn)
    66  		in                     = json.NewDecoder(clientConn)
    67  		successes              = make(chan subConfirmation)
    68  		notifications          = make(chan subscriptionResult)
    69  		errors                 = make(chan error, subCount*notificationCount+1)
    70  	)
    71  
    72  	// setup and start server
    73  	for _, namespace := range namespaces {
    74  		if err := server.RegisterName(namespace, service); err != nil {
    75  			t.Fatalf("unable to register test service %v", err)
    76  		}
    77  	}
    78  	go server.ServeCodec(NewCodec(serverConn), 0)
    79  	defer server.Stop()
    80  
    81  	// wait for message and write them to the given channels
    82  	go waitForMessages(in, successes, notifications, errors)
    83  
    84  	// create subscriptions one by one
    85  	for i, namespace := range namespaces {
    86  		request := map[string]interface{}{
    87  			"id":      i,
    88  			"method":  fmt.Sprintf("%s_subscribe", namespace),
    89  			"jsonrpc": "2.0",
    90  			"params":  []interface{}{"someSubscription", notificationCount, i},
    91  		}
    92  		if err := out.Encode(&request); err != nil {
    93  			t.Fatalf("Could not create subscription: %v", err)
    94  		}
    95  	}
    96  
    97  	timeout := time.After(30 * time.Second)
    98  	subids := make(map[string]string, subCount)
    99  	count := make(map[string]int, subCount)
   100  	allReceived := func() bool {
   101  		done := len(count) == subCount
   102  		for _, c := range count {
   103  			if c < notificationCount {
   104  				done = false
   105  			}
   106  		}
   107  		return done
   108  	}
   109  	for !allReceived() {
   110  		select {
   111  		case confirmation := <-successes: // subscription created
   112  			subids[namespaces[confirmation.reqid]] = string(confirmation.subid)
   113  		case notification := <-notifications:
   114  			count[notification.ID]++
   115  		case err := <-errors:
   116  			t.Fatal(err)
   117  		case <-timeout:
   118  			for _, namespace := range namespaces {
   119  				subid, found := subids[namespace]
   120  				if !found {
   121  					t.Errorf("subscription for %q not created", namespace)
   122  					continue
   123  				}
   124  				if count, found := count[subid]; !found || count < notificationCount {
   125  					t.Errorf("didn't receive all notifications (%d<%d) in time for namespace %q", count, notificationCount, namespace)
   126  				}
   127  			}
   128  			t.Fatal("timed out")
   129  		}
   130  	}
   131  }
   132  
   133  // This test checks that unsubscribing works.
   134  func TestServerUnsubscribe(t *testing.T) {
   135  	p1, p2 := net.Pipe()
   136  	defer p2.Close()
   137  
   138  	// Start the server.
   139  	server := newTestServer()
   140  	service := &notificationTestService{unsubscribed: make(chan string, 1)}
   141  	server.RegisterName("nftest2", service)
   142  	go server.ServeCodec(NewCodec(p1), 0)
   143  
   144  	// Subscribe.
   145  	p2.SetDeadline(time.Now().Add(10 * time.Second))
   146  	p2.Write([]byte(`{"jsonrpc":"2.0","id":1,"method":"nftest2_subscribe","params":["someSubscription",0,10]}`))
   147  
   148  	// Handle received messages.
   149  	var (
   150  		resps         = make(chan subConfirmation)
   151  		notifications = make(chan subscriptionResult)
   152  		errors        = make(chan error, 1)
   153  	)
   154  	go waitForMessages(json.NewDecoder(p2), resps, notifications, errors)
   155  
   156  	// Receive the subscription ID.
   157  	var sub subConfirmation
   158  	select {
   159  	case sub = <-resps:
   160  	case err := <-errors:
   161  		t.Fatal(err)
   162  	}
   163  
   164  	// Unsubscribe and check that it is handled on the server side.
   165  	p2.Write([]byte(`{"jsonrpc":"2.0","method":"nftest2_unsubscribe","params":["` + sub.subid + `"]}`))
   166  	for {
   167  		select {
   168  		case id := <-service.unsubscribed:
   169  			if id != string(sub.subid) {
   170  				t.Errorf("wrong subscription ID unsubscribed")
   171  			}
   172  			return
   173  		case err := <-errors:
   174  			t.Fatal(err)
   175  		case <-notifications:
   176  			// drop notifications
   177  		}
   178  	}
   179  }
   180  
// subConfirmation records a successful subscribe response seen by the test
// client, as produced by readAndValidateMessage.
type subConfirmation struct {
	reqid int // request ID decoded from the response's "id" field
	subid ID  // subscription ID decoded from the response's "result" field
}
   185  
   186  // waitForMessages reads RPC messages from 'in' and dispatches them into the given channels.
   187  // It stops if there is an error.
   188  func waitForMessages(in *json.Decoder, successes chan subConfirmation, notifications chan subscriptionResult, errors chan error) {
   189  	for {
   190  		resp, notification, err := readAndValidateMessage(in)
   191  		if err != nil {
   192  			errors <- err
   193  			return
   194  		} else if resp != nil {
   195  			successes <- *resp
   196  		} else {
   197  			notifications <- *notification
   198  		}
   199  	}
   200  }
   201  
   202  func readAndValidateMessage(in *json.Decoder) (*subConfirmation, *subscriptionResult, error) {
   203  	var msg jsonrpcMessage
   204  	if err := in.Decode(&msg); err != nil {
   205  		return nil, nil, fmt.Errorf("decode error: %v", err)
   206  	}
   207  	switch {
   208  	case msg.isNotification():
   209  		var res subscriptionResult
   210  		if err := json.Unmarshal(msg.Params, &res); err != nil {
   211  			return nil, nil, fmt.Errorf("invalid subscription result: %v", err)
   212  		}
   213  		return nil, &res, nil
   214  	case msg.isResponse():
   215  		var c subConfirmation
   216  		if msg.Error != nil {
   217  			return nil, nil, msg.Error
   218  		} else if err := json.Unmarshal(msg.Result, &c.subid); err != nil {
   219  			return nil, nil, fmt.Errorf("invalid response: %v", err)
   220  		} else {
   221  			json.Unmarshal(msg.ID, &c.reqid)
   222  			return &c, nil, nil
   223  		}
   224  	default:
   225  		return nil, nil, fmt.Errorf("unrecognized message: %v", msg)
   226  	}
   227  }
   228  
   229  type mockConn struct {
   230  	enc *json.Encoder
   231  }
   232  
   233  // writeJSON writes a message to the connection.
   234  func (c *mockConn) writeJSON(ctx context.Context, msg interface{}, isError bool) error {
   235  	return c.enc.Encode(msg)
   236  }
   237  
   238  // closed returns a channel which is closed when the connection is closed.
   239  func (c *mockConn) closed() <-chan interface{} { return nil }
   240  
   241  // remoteAddr returns the peer address of the connection.
   242  func (c *mockConn) remoteAddr() string { return "" }
   243  
   244  // BenchmarkNotify benchmarks the performance of notifying a subscription.
   245  func BenchmarkNotify(b *testing.B) {
   246  	id := ID("test")
   247  	notifier := &Notifier{
   248  		h:         &handler{conn: &mockConn{json.NewEncoder(io.Discard)}},
   249  		sub:       &Subscription{ID: id},
   250  		activated: true,
   251  	}
   252  	msg := &types.Header{
   253  		ParentHash: common.HexToHash("0x01"),
   254  		Number:     big.NewInt(100),
   255  	}
   256  	b.ResetTimer()
   257  	for i := 0; i < b.N; i++ {
   258  		notifier.Notify(id, msg)
   259  	}
   260  }
   261  
   262  func TestNotify(t *testing.T) {
   263  	out := new(bytes.Buffer)
   264  	id := ID("test")
   265  	notifier := &Notifier{
   266  		h:         &handler{conn: &mockConn{json.NewEncoder(out)}},
   267  		sub:       &Subscription{ID: id},
   268  		activated: true,
   269  	}
   270  	msg := &types.Header{
   271  		ParentHash: common.HexToHash("0x01"),
   272  		Number:     big.NewInt(100),
   273  	}
   274  	notifier.Notify(id, msg)
   275  	have := strings.TrimSpace(out.String())
   276  	want := `{"jsonrpc":"2.0","method":"_subscription","params":{"subscription":"test","result":{"parentHash":"0x0000000000000000000000000000000000000000000000000000000000000001","sha3Uncles":"0x0000000000000000000000000000000000000000000000000000000000000000","miner":"0x0000000000000000000000000000000000000000","stateRoot":"0x0000000000000000000000000000000000000000000000000000000000000000","transactionsRoot":"0x0000000000000000000000000000000000000000000000000000000000000000","receiptsRoot":"0x0000000000000000000000000000000000000000000000000000000000000000","logsBloom":"0x00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000","difficulty":null,"number":"0x64","gasLimit":"0x0","gasUsed":"0x0","timestamp":"0x0","extraData":"0x","mixHash":"0x0000000000000000000000000000000000000000000000000000000000000000","nonce":"0x0000000000000000","baseFeePerGas":null,"withdrawalsRoot":null,"blobGasUsed":null,"excessBlobGas":null,"parentBeaconBlockRoot":null,"hash":"0xe5fb877dde471b45b9742bb4bb4b3d74a761e2fb7cb849a3d2b687eed90fb604"}}}`
   277  	if have != want {
   278  		t.Errorf("have:\n%v\nwant:\n%v\n", have, want)
   279  	}
   280  }