github.com/qri-io/qri@v0.10.1-0.20220104210721-c771715036cb/lib/websocket/websocket_test.go (about)

     1  package websocket
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"context"
     7  	"net"
     8  	"net/http"
     9  	"net/http/httptest"
    10  	"strings"
    11  	"testing"
    12  
    13  	"github.com/qri-io/qri/auth/key"
    14  	testkeys "github.com/qri-io/qri/auth/key/test"
    15  	"github.com/qri-io/qri/auth/token"
    16  	"github.com/qri-io/qri/event"
    17  )
    18  
    19  func TestWebsocket(t *testing.T) {
    20  	ctx, cancel := context.WithCancel(context.Background())
    21  	defer cancel()
    22  
    23  	// create key store & add test key
    24  	kd := testkeys.GetKeyData(0)
    25  	ks, err := key.NewMemStore()
    26  	if err != nil {
    27  		t.Fatal(err)
    28  	}
    29  	if err := ks.AddPubKey(context.Background(), kd.KeyID, kd.PrivKey.GetPublic()); err != nil {
    30  		t.Fatal(err)
    31  	}
    32  
    33  	// create bus
    34  	bus := event.NewBus(ctx)
    35  
    36  	subsCount := bus.NumSubscribers()
    37  
    38  	// create Handler
    39  	websocketHandler, err := NewHandler(ctx, bus, ks)
    40  	if err != nil {
    41  		t.Fatal(err)
    42  	}
    43  	wsh := websocketHandler.(*connections)
    44  
    45  	// websockets should subscribe the message handler
    46  	if bus.NumSubscribers() != subsCount+1 {
    47  		t.Fatalf("failed to subscribe websocket handlers")
    48  	}
    49  
    50  	// add connection
    51  	randIDStr := "test_connection_id_str"
    52  	setIDRand(strings.NewReader(randIDStr))
    53  	connID := newID()
    54  	setIDRand(strings.NewReader(randIDStr))
    55  
    56  	wsh.ConnectionHandler(mockWriterAndRequest())
    57  	if _, err := wsh.getConn(connID); err != nil {
    58  		t.Fatal("ConnectionHandler did not create a connection")
    59  	}
    60  
    61  	// create a token from a private key
    62  	kd = testkeys.GetKeyData(0)
    63  	tokenStr, err := token.NewPrivKeyAuthToken(kd.PrivKey, kd.KeyID.String(), 0)
    64  	if err != nil {
    65  		t.Fatal(err)
    66  	}
    67  	// upgrade connection w/ valid token
    68  	wsh.subscribeConn(connID, tokenStr)
    69  	proID := kd.KeyID.String()
    70  	gotConnIDs, err := wsh.getConnIDs(proID)
    71  	if err != nil {
    72  		t.Fatal("connections.subscribeConn did not add profileID or conn to subscriptions map")
    73  	}
    74  	if _, ok := gotConnIDs[connID]; !ok {
    75  		t.Fatalf("connections.subscribeConn added incorrect connID to subscriptions map, expected %q, got %q", connID, gotConnIDs)
    76  	}
    77  
    78  	// unsubscribe connection via profileID
    79  	wsh.unsubscribeConn(proID, "")
    80  	if _, err := wsh.getConnIDs(proID); err == nil {
    81  		t.Fatal("connections.unsubscribeConn did not remove the profileID from the subscription map")
    82  	}
    83  	wsc, err := wsh.getConn(connID)
    84  	if err != nil {
    85  		t.Fatalf("connection %s not found", connID)
    86  	}
    87  	if wsc.profileID != "" {
    88  		t.Error("connections.unsubscribeConn did not remove the profileID from the conn")
    89  	}
    90  
    91  	// remove the connection
    92  	wsh.removeConn(connID)
    93  	if _, err := wsh.getConn(connID); err == nil {
    94  		t.Fatal("connections.removeConn did not remove the connection from the map of conns")
    95  	}
    96  }
    97  
    98  func TestWebsocketUnsubscribe(t *testing.T) {
    99  	ctx, cancel := context.WithCancel(context.Background())
   100  	defer cancel()
   101  
   102  	// create key store & add test key
   103  	kd := testkeys.GetKeyData(0)
   104  	ks, err := key.NewMemStore()
   105  	if err != nil {
   106  		t.Fatal(err)
   107  	}
   108  	if err := ks.AddPubKey(context.Background(), kd.KeyID, kd.PrivKey.GetPublic()); err != nil {
   109  		t.Fatal(err)
   110  	}
   111  
   112  	// create bus
   113  	bus := event.NewBus(ctx)
   114  
   115  	subsCount := bus.NumSubscribers()
   116  
   117  	// create Handler
   118  	websocketHandler, err := NewHandler(ctx, bus, ks)
   119  	if err != nil {
   120  		t.Fatal(err)
   121  	}
   122  	wsh := websocketHandler.(*connections)
   123  
   124  	// websockets should subscribe the message handler
   125  	if bus.NumSubscribers() != subsCount+1 {
   126  		t.Fatalf("failed to subscribe websocket handlers")
   127  	}
   128  
   129  	// add connection
   130  	randIDStr := "test_connection_id_str"
   131  	setIDRand(strings.NewReader(randIDStr))
   132  	connID := newID()
   133  	setIDRand(strings.NewReader(randIDStr))
   134  
   135  	wsh.ConnectionHandler(mockWriterAndRequest())
   136  	if _, err := wsh.getConn(connID); err != nil {
   137  		t.Fatal("ConnectionHandler did not create a connection")
   138  	}
   139  
   140  	// create a token from a private key with no profileID
   141  	kd = testkeys.GetKeyData(0)
   142  	tokenStr, err := token.NewPrivKeyAuthToken(kd.PrivKey, "", 0)
   143  	if err != nil {
   144  		t.Fatal(err)
   145  	}
   146  	proID := kd.KeyID.String()
   147  
   148  	err = wsh.subscribeConn(connID, tokenStr)
   149  	if err == nil {
   150  		t.Fatal("connections.subscribeConn subscribed connection with no profileID to subscriptions map")
   151  	}
   152  	_, err = wsh.getConn(connID)
   153  	if err == nil {
   154  		t.Fatal("connections.subscribeConn got connection which shouldn't exist")
   155  	}
   156  	gotConnIDs, err := wsh.getConnIDs(proID)
   157  	if err == nil {
   158  		t.Fatal("connections.subscribeConn contains profileID in subscriptions map")
   159  	}
   160  	if _, ok := gotConnIDs[connID]; ok {
   161  		t.Fatalf("connections.subscribeConn added incorrect connID to subscriptions map, expected %q, got %q", connID, gotConnIDs)
   162  	}
   163  }
   164  
   165  func mockWriterAndRequest() (http.ResponseWriter, *http.Request) {
   166  	w := mockHijacker{
   167  		ResponseWriter: httptest.NewRecorder(),
   168  	}
   169  
   170  	r := httptest.NewRequest("GET", "/", nil)
   171  	r.Header.Set("Connection", "keep-alive, Upgrade")
   172  	r.Header.Set("Upgrade", "websocket")
   173  	r.Header.Set("Sec-WebSocket-Version", "13")
   174  	r.Header.Set("Sec-WebSocket-Key", "test_key")
   175  	return w, r
   176  }
   177  
   178  type mockHijacker struct {
   179  	http.ResponseWriter
   180  }
   181  
   182  var _ http.Hijacker = mockHijacker{}
   183  
   184  func (mj mockHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
   185  	c, _ := net.Pipe()
   186  	r := bufio.NewReader(strings.NewReader("test_reader"))
   187  	w := bufio.NewWriter(&bytes.Buffer{})
   188  	rw := bufio.NewReadWriter(r, w)
   189  	return c, rw, nil
   190  }