github.com/qri-io/qri@v0.10.1-0.20220104210721-c771715036cb/lib/websocket/websocket_test.go (about) 1 package websocket 2 3 import ( 4 "bufio" 5 "bytes" 6 "context" 7 "net" 8 "net/http" 9 "net/http/httptest" 10 "strings" 11 "testing" 12 13 "github.com/qri-io/qri/auth/key" 14 testkeys "github.com/qri-io/qri/auth/key/test" 15 "github.com/qri-io/qri/auth/token" 16 "github.com/qri-io/qri/event" 17 ) 18 19 func TestWebsocket(t *testing.T) { 20 ctx, cancel := context.WithCancel(context.Background()) 21 defer cancel() 22 23 // create key store & add test key 24 kd := testkeys.GetKeyData(0) 25 ks, err := key.NewMemStore() 26 if err != nil { 27 t.Fatal(err) 28 } 29 if err := ks.AddPubKey(context.Background(), kd.KeyID, kd.PrivKey.GetPublic()); err != nil { 30 t.Fatal(err) 31 } 32 33 // create bus 34 bus := event.NewBus(ctx) 35 36 subsCount := bus.NumSubscribers() 37 38 // create Handler 39 websocketHandler, err := NewHandler(ctx, bus, ks) 40 if err != nil { 41 t.Fatal(err) 42 } 43 wsh := websocketHandler.(*connections) 44 45 // websockets should subscribe the message handler 46 if bus.NumSubscribers() != subsCount+1 { 47 t.Fatalf("failed to subscribe websocket handlers") 48 } 49 50 // add connection 51 randIDStr := "test_connection_id_str" 52 setIDRand(strings.NewReader(randIDStr)) 53 connID := newID() 54 setIDRand(strings.NewReader(randIDStr)) 55 56 wsh.ConnectionHandler(mockWriterAndRequest()) 57 if _, err := wsh.getConn(connID); err != nil { 58 t.Fatal("ConnectionHandler did not create a connection") 59 } 60 61 // create a token from a private key 62 kd = testkeys.GetKeyData(0) 63 tokenStr, err := token.NewPrivKeyAuthToken(kd.PrivKey, kd.KeyID.String(), 0) 64 if err != nil { 65 t.Fatal(err) 66 } 67 // upgrade connection w/ valid token 68 wsh.subscribeConn(connID, tokenStr) 69 proID := kd.KeyID.String() 70 gotConnIDs, err := wsh.getConnIDs(proID) 71 if err != nil { 72 t.Fatal("connections.subscribeConn did not add profileID or conn to subscriptions map") 73 } 74 if _, ok := gotConnIDs[connID]; !ok { 75 t.Fatalf("connections.subscribeConn added incorrect connID to subscriptions map, expected %q, got %q", connID, gotConnIDs) 76 } 77 78 // unsubscribe connection via profileID 79 wsh.unsubscribeConn(proID, "") 80 if _, err := wsh.getConnIDs(proID); err == nil { 81 t.Fatal("connections.unsubscribeConn did not remove the profileID from the subscription map") 82 } 83 wsc, err := wsh.getConn(connID) 84 if err != nil { 85 t.Fatalf("connection %s not found", connID) 86 } 87 if wsc.profileID != "" { 88 t.Error("connections.unsubscribeConn did not remove the profileID from the conn") 89 } 90 91 // remove the connection 92 wsh.removeConn(connID) 93 if _, err := wsh.getConn(connID); err == nil { 94 t.Fatal("connections.removeConn did not remove the connection from the map of conns") 95 } 96 } 97 98 func TestWebsocketUnsubscribe(t *testing.T) { 99 ctx, cancel := context.WithCancel(context.Background()) 100 defer cancel() 101 102 // create key store & add test key 103 kd := testkeys.GetKeyData(0) 104 ks, err := key.NewMemStore() 105 if err != nil { 106 t.Fatal(err) 107 } 108 if err := ks.AddPubKey(context.Background(), kd.KeyID, kd.PrivKey.GetPublic()); err != nil { 109 t.Fatal(err) 110 } 111 112 // create bus 113 bus := event.NewBus(ctx) 114 115 subsCount := bus.NumSubscribers() 116 117 // create Handler 118 websocketHandler, err := NewHandler(ctx, bus, ks) 119 if err != nil { 120 t.Fatal(err) 121 } 122 wsh := websocketHandler.(*connections) 123 124 // websockets should subscribe the message handler 125 if bus.NumSubscribers() != subsCount+1 { 126 t.Fatalf("failed to subscribe websocket handlers") 127 } 128 129 // add connection 130 randIDStr := "test_connection_id_str" 131 setIDRand(strings.NewReader(randIDStr)) 132 connID := newID() 133 setIDRand(strings.NewReader(randIDStr)) 134 135 wsh.ConnectionHandler(mockWriterAndRequest()) 136 if _, err := wsh.getConn(connID); err != nil { 137 t.Fatal("ConnectionHandler did not create a connection") 138 } 139 140 // create a token from a private key with no profileID 141 kd = testkeys.GetKeyData(0) 142 tokenStr, err := token.NewPrivKeyAuthToken(kd.PrivKey, "", 0) 143 if err != nil { 144 t.Fatal(err) 145 } 146 proID := kd.KeyID.String() 147 148 err = wsh.subscribeConn(connID, tokenStr) 149 if err == nil { 150 t.Fatal("connections.subscribeConn subscribed connection with no profileID to subscriptions map") 151 } 152 _, err = wsh.getConn(connID) 153 if err == nil { 154 t.Fatal("connections.subscribeConn got connection which shouldn't exist") 155 } 156 gotConnIDs, err := wsh.getConnIDs(proID) 157 if err == nil { 158 t.Fatal("connections.subscribeConn contains profileID in subscriptions map") 159 } 160 if _, ok := gotConnIDs[connID]; ok { 161 t.Fatalf("connections.subscribeConn added incorrect connID to subscriptions map, expected %q, got %q", connID, gotConnIDs) 162 } 163 } 164 165 func mockWriterAndRequest() (http.ResponseWriter, *http.Request) { 166 w := mockHijacker{ 167 ResponseWriter: httptest.NewRecorder(), 168 } 169 170 r := httptest.NewRequest("GET", "/", nil) 171 r.Header.Set("Connection", "keep-alive, Upgrade") 172 r.Header.Set("Upgrade", "websocket") 173 r.Header.Set("Sec-WebSocket-Version", "13") 174 r.Header.Set("Sec-WebSocket-Key", "test_key") 175 return w, r 176 } 177 178 type mockHijacker struct { 179 http.ResponseWriter 180 } 181 182 var _ http.Hijacker = mockHijacker{} 183 184 func (mj mockHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { 185 c, _ := net.Pipe() 186 r := bufio.NewReader(strings.NewReader("test_reader")) 187 w := bufio.NewWriter(&bytes.Buffer{}) 188 rw := bufio.NewReadWriter(r, w) 189 return c, rw, nil 190 }