github.com/arieschain/arieschain@v0.0.0-20191023063405-37c074544356/rpc/subscription_test.go (about)

     1  package rpc
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"fmt"
     7  	"net"
     8  	"sync"
     9  	"testing"
    10  	"time"
    11  )
    12  
    13  type NotificationTestService struct {
    14  	mu           sync.Mutex
    15  	unsubscribed bool
    16  
    17  	gotHangSubscriptionReq  chan struct{}
    18  	unblockHangSubscription chan struct{}
    19  }
    20  
    21  func (s *NotificationTestService) Echo(i int) int {
    22  	return i
    23  }
    24  
    25  func (s *NotificationTestService) wasUnsubCallbackCalled() bool {
    26  	s.mu.Lock()
    27  	defer s.mu.Unlock()
    28  	return s.unsubscribed
    29  }
    30  
    31  func (s *NotificationTestService) Unsubscribe(subid string) {
    32  	s.mu.Lock()
    33  	s.unsubscribed = true
    34  	s.mu.Unlock()
    35  }
    36  
    37  func (s *NotificationTestService) SomeSubscription(ctx context.Context, n, val int) (*Subscription, error) {
    38  	notifier, supported := NotifierFromContext(ctx)
    39  	if !supported {
    40  		return nil, ErrNotificationsUnsupported
    41  	}
    42  
    43  	// by explicitly creating an subscription we make sure that the subscription id is send back to the client
    44  	// before the first subscription.Notify is called. Otherwise the events might be send before the response
    45  	// for the qct_subscribe method.
    46  	subscription := notifier.CreateSubscription()
    47  
    48  	go func() {
    49  		// test expects n events, if we begin sending event immediately some events
    50  		// will probably be dropped since the subscription ID might not be send to
    51  		// the client.
    52  		time.Sleep(5 * time.Second)
    53  		for i := 0; i < n; i++ {
    54  			if err := notifier.Notify(subscription.ID, val+i); err != nil {
    55  				return
    56  			}
    57  		}
    58  
    59  		select {
    60  		case <-notifier.Closed():
    61  			s.mu.Lock()
    62  			s.unsubscribed = true
    63  			s.mu.Unlock()
    64  		case <-subscription.Err():
    65  			s.mu.Lock()
    66  			s.unsubscribed = true
    67  			s.mu.Unlock()
    68  		}
    69  	}()
    70  
    71  	return subscription, nil
    72  }
    73  
    74  // HangSubscription blocks on s.unblockHangSubscription before
    75  // sending anything.
    76  func (s *NotificationTestService) HangSubscription(ctx context.Context, val int) (*Subscription, error) {
    77  	notifier, supported := NotifierFromContext(ctx)
    78  	if !supported {
    79  		return nil, ErrNotificationsUnsupported
    80  	}
    81  
    82  	s.gotHangSubscriptionReq <- struct{}{}
    83  	<-s.unblockHangSubscription
    84  	subscription := notifier.CreateSubscription()
    85  
    86  	go func() {
    87  		notifier.Notify(subscription.ID, val)
    88  	}()
    89  	return subscription, nil
    90  }
    91  
    92  func TestNotifications(t *testing.T) {
    93  	server := NewServer()
    94  	service := &NotificationTestService{}
    95  
    96  	if err := server.RegisterName("eth", service); err != nil {
    97  		t.Fatalf("unable to register test service %v", err)
    98  	}
    99  
   100  	clientConn, serverConn := net.Pipe()
   101  
   102  	go server.ServeCodec(NewJSONCodec(serverConn), OptionMethodInvocation|OptionSubscriptions)
   103  
   104  	out := json.NewEncoder(clientConn)
   105  	in := json.NewDecoder(clientConn)
   106  
   107  	n := 5
   108  	val := 12345
   109  	request := map[string]interface{}{
   110  		"id":      1,
   111  		"method":  "qct_subscribe",
   112  		"version": "2.0",
   113  		"params":  []interface{}{"someSubscription", n, val},
   114  	}
   115  
   116  	// create subscription
   117  	if err := out.Encode(request); err != nil {
   118  		t.Fatal(err)
   119  	}
   120  
   121  	var subid string
   122  	response := jsonSuccessResponse{Result: subid}
   123  	if err := in.Decode(&response); err != nil {
   124  		t.Fatal(err)
   125  	}
   126  
   127  	var ok bool
   128  	if _, ok = response.Result.(string); !ok {
   129  		t.Fatalf("expected subscription id, got %T", response.Result)
   130  	}
   131  
   132  	for i := 0; i < n; i++ {
   133  		var notification jsonNotification
   134  		if err := in.Decode(&notification); err != nil {
   135  			t.Fatalf("%v", err)
   136  		}
   137  
   138  		if int(notification.Params.Result.(float64)) != val+i {
   139  			t.Fatalf("expected %d, got %d", val+i, notification.Params.Result)
   140  		}
   141  	}
   142  
   143  	clientConn.Close() // causes notification unsubscribe callback to be called
   144  	time.Sleep(1 * time.Second)
   145  
   146  	if !service.wasUnsubCallbackCalled() {
   147  		t.Error("unsubscribe callback not called after closing connection")
   148  	}
   149  }
   150  
   151  func waitForMessages(t *testing.T, in *json.Decoder, successes chan<- jsonSuccessResponse,
   152  	failures chan<- jsonErrResponse, notifications chan<- jsonNotification, errors chan<- error) {
   153  
   154  	// read and parse server messages
   155  	for {
   156  		var rmsg json.RawMessage
   157  		if err := in.Decode(&rmsg); err != nil {
   158  			return
   159  		}
   160  
   161  		var responses []map[string]interface{}
   162  		if rmsg[0] == '[' {
   163  			if err := json.Unmarshal(rmsg, &responses); err != nil {
   164  				errors <- fmt.Errorf("Received invalid message: %s", rmsg)
   165  				return
   166  			}
   167  		} else {
   168  			var msg map[string]interface{}
   169  			if err := json.Unmarshal(rmsg, &msg); err != nil {
   170  				errors <- fmt.Errorf("Received invalid message: %s", rmsg)
   171  				return
   172  			}
   173  			responses = append(responses, msg)
   174  		}
   175  
   176  		for _, msg := range responses {
   177  			// determine what kind of msg was received and broadcast
   178  			// it to over the corresponding channel
   179  			if _, found := msg["result"]; found {
   180  				successes <- jsonSuccessResponse{
   181  					Version: msg["jsonrpc"].(string),
   182  					Id:      msg["id"],
   183  					Result:  msg["result"],
   184  				}
   185  				continue
   186  			}
   187  			if _, found := msg["error"]; found {
   188  				params := msg["params"].(map[string]interface{})
   189  				failures <- jsonErrResponse{
   190  					Version: msg["jsonrpc"].(string),
   191  					Id:      msg["id"],
   192  					Error:   jsonError{int(params["subscription"].(float64)), params["message"].(string), params["data"]},
   193  				}
   194  				continue
   195  			}
   196  			if _, found := msg["params"]; found {
   197  				params := msg["params"].(map[string]interface{})
   198  				notifications <- jsonNotification{
   199  					Version: msg["jsonrpc"].(string),
   200  					Method:  msg["method"].(string),
   201  					Params:  jsonSubscription{params["subscription"].(string), params["result"]},
   202  				}
   203  				continue
   204  			}
   205  			errors <- fmt.Errorf("Received invalid message: %s", msg)
   206  		}
   207  	}
   208  }
   209  
   210  // TestSubscriptionMultipleNamespaces ensures that subscriptions can exists
   211  // for multiple different namespaces.
   212  func TestSubscriptionMultipleNamespaces(t *testing.T) {
   213  	var (
   214  		namespaces             = []string{"eth", "shh", "bzz"}
   215  		server                 = NewServer()
   216  		service                = NotificationTestService{}
   217  		clientConn, serverConn = net.Pipe()
   218  
   219  		out           = json.NewEncoder(clientConn)
   220  		in            = json.NewDecoder(clientConn)
   221  		successes     = make(chan jsonSuccessResponse)
   222  		failures      = make(chan jsonErrResponse)
   223  		notifications = make(chan jsonNotification)
   224  
   225  		errors = make(chan error, 10)
   226  	)
   227  
   228  	// setup and start server
   229  	for _, namespace := range namespaces {
   230  		if err := server.RegisterName(namespace, &service); err != nil {
   231  			t.Fatalf("unable to register test service %v", err)
   232  		}
   233  	}
   234  
   235  	go server.ServeCodec(NewJSONCodec(serverConn), OptionMethodInvocation|OptionSubscriptions)
   236  	defer server.Stop()
   237  
   238  	// wait for message and write them to the given channels
   239  	go waitForMessages(t, in, successes, failures, notifications, errors)
   240  
   241  	// create subscriptions one by one
   242  	n := 3
   243  	for i, namespace := range namespaces {
   244  		request := map[string]interface{}{
   245  			"id":      i,
   246  			"method":  fmt.Sprintf("%s_subscribe", namespace),
   247  			"version": "2.0",
   248  			"params":  []interface{}{"someSubscription", n, i},
   249  		}
   250  
   251  		if err := out.Encode(&request); err != nil {
   252  			t.Fatalf("Could not create subscription: %v", err)
   253  		}
   254  	}
   255  
   256  	// create all subscriptions in 1 batch
   257  	var requests []interface{}
   258  	for i, namespace := range namespaces {
   259  		requests = append(requests, map[string]interface{}{
   260  			"id":      i,
   261  			"method":  fmt.Sprintf("%s_subscribe", namespace),
   262  			"version": "2.0",
   263  			"params":  []interface{}{"someSubscription", n, i},
   264  		})
   265  	}
   266  
   267  	if err := out.Encode(&requests); err != nil {
   268  		t.Fatalf("Could not create subscription in batch form: %v", err)
   269  	}
   270  
   271  	timeout := time.After(30 * time.Second)
   272  	subids := make(map[string]string, 2*len(namespaces))
   273  	count := make(map[string]int, 2*len(namespaces))
   274  
   275  	for {
   276  		done := true
   277  		for id := range count {
   278  			if count, found := count[id]; !found || count < (2*n) {
   279  				done = false
   280  			}
   281  		}
   282  
   283  		if done && len(count) == len(namespaces) {
   284  			break
   285  		}
   286  
   287  		select {
   288  		case err := <-errors:
   289  			t.Fatal(err)
   290  		case suc := <-successes: // subscription created
   291  			subids[namespaces[int(suc.Id.(float64))]] = suc.Result.(string)
   292  		case failure := <-failures:
   293  			t.Errorf("received error: %v", failure.Error)
   294  		case notification := <-notifications:
   295  			if cnt, found := count[notification.Params.Subscription]; found {
   296  				count[notification.Params.Subscription] = cnt + 1
   297  			} else {
   298  				count[notification.Params.Subscription] = 1
   299  			}
   300  		case <-timeout:
   301  			for _, namespace := range namespaces {
   302  				subid, found := subids[namespace]
   303  				if !found {
   304  					t.Errorf("Subscription for '%s' not created", namespace)
   305  					continue
   306  				}
   307  				if count, found := count[subid]; !found || count < n {
   308  					t.Errorf("Didn't receive all notifications (%d<%d) in time for namespace '%s'", count, n, namespace)
   309  				}
   310  			}
   311  			return
   312  		}
   313  	}
   314  }