decred.org/dcrdex@v1.0.5/client/comms/wsconn_test.go (about) 1 package comms 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/elliptic" 7 "encoding/hex" 8 "errors" 9 "fmt" 10 "net" 11 "net/http" 12 "os" 13 "runtime" 14 "sync" 15 "sync/atomic" 16 "testing" 17 "time" 18 19 "decred.org/dcrdex/dex" 20 "decred.org/dcrdex/dex/msgjson" 21 "github.com/decred/dcrd/certgen" 22 "github.com/gorilla/websocket" 23 ) 24 25 var tLogger = dex.StdOutLogger("conn_TEST", dex.LevelTrace) 26 27 func makeRequest(id uint64, route string, msg any) *msgjson.Message { 28 req, _ := msgjson.NewRequest(id, route, msg) 29 return req 30 } 31 32 // genCertPair generates a key/cert pair to the paths provided. 33 func genCertPair(certFile, keyFile string, altDNSNames []string) error { 34 tLogger.Infof("Generating TLS certificates...") 35 36 org := "dcrdex autogenerated cert" 37 validUntil := time.Now().Add(10 * 365 * 24 * time.Hour) 38 cert, key, err := certgen.NewTLSCertPair(elliptic.P521(), org, 39 validUntil, altDNSNames) 40 if err != nil { 41 return err 42 } 43 44 // Write cert and key files. 45 if err = os.WriteFile(certFile, cert, 0644); err != nil { 46 return err 47 } 48 if err = os.WriteFile(keyFile, key, 0600); err != nil { 49 os.Remove(certFile) 50 return err 51 } 52 53 tLogger.Infof("Done generating TLS certificates") 54 return nil 55 } 56 57 func TestWsConn(t *testing.T) { 58 // Must wait for goroutines, especially the ones that capture t. 59 var wg sync.WaitGroup 60 defer wg.Wait() 61 62 upgrader := websocket.Upgrader{} 63 64 pingCh := make(chan struct{}) 65 readPumpCh := make(chan any) 66 writePumpCh := make(chan *msgjson.Message) 67 ctx, cancel := context.WithCancel(context.Background()) 68 defer cancel() 69 70 type conn struct { 71 sync.WaitGroup 72 *websocket.Conn 73 } 74 var clientMtx sync.Mutex 75 clients := make(map[uint64]*conn) 76 77 // server.Shutdown does not wait for hijacked connections, and pong handler 78 // uses t.Logf. 79 defer func() { 80 clientMtx.Lock() 81 for id, h := range clients { 82 h.Close() 83 h.Wait() 84 delete(clients, id) 85 } 86 clientMtx.Unlock() 87 }() 88 89 var id uint64 90 // server's "/ws" handler 91 handler := func(w http.ResponseWriter, r *http.Request) { 92 t.Helper() 93 id := atomic.AddUint64(&id, 1) // shadow id 94 hCtx, hCancel := context.WithCancel(ctx) 95 96 c, err := upgrader.Upgrade(w, r, nil) 97 if err != nil { 98 t.Errorf("unable to upgrade http connection: %s", err) 99 } 100 101 ch := &conn{Conn: c} 102 clientMtx.Lock() 103 clients[id] = ch 104 clientMtx.Unlock() 105 106 c.SetPongHandler(func(string) error { 107 t.Logf("handler #%d: pong received", id) 108 return nil 109 }) 110 111 ch.Add(1) 112 go func() { 113 defer ch.Done() 114 for { 115 select { 116 case <-pingCh: 117 err := c.WriteControl(websocket.PingMessage, []byte{}, 118 time.Now().Add(writeWait)) 119 if err != nil { 120 if hCtx.Err() == nil { 121 // Only a failure if the server isn't shutting down. 122 t.Errorf("handler #%d: ping error: %v", id, err) 123 } 124 return 125 } 126 127 t.Logf("handler #%d: ping sent", id) 128 129 case msg := <-readPumpCh: 130 err := c.WriteJSON(msg) 131 if err != nil { 132 t.Errorf("handler #%d: write error: %v", id, err) 133 return 134 } 135 136 case <-hCtx.Done(): 137 return 138 } 139 } 140 }() 141 142 ch.Add(1) 143 go func() { 144 defer ch.Done() 145 for { 146 mType, message, err := c.ReadMessage() 147 if err != nil { 148 hCancel() 149 c.Close() 150 151 // If the context has been canceled, don't do anything. 152 if hCtx.Err() != nil { 153 return 154 } 155 156 if websocket.IsCloseError(err, websocket.CloseNormalClosure) { 157 // Terminate on a normal close message. 158 return 159 } 160 161 t.Errorf("handler #%d: read error: %v\n", id, err) 162 return 163 } 164 165 if mType == websocket.TextMessage { 166 msg, err := msgjson.DecodeMessage(message) 167 if err != nil { 168 t.Errorf("handler #%d: decode error: %v", id, err) 169 continue // Don't hang up. 170 } 171 172 writePumpCh <- msg 173 } 174 } 175 }() 176 } 177 178 certFile, err := os.CreateTemp("", "certfile") 179 if err != nil { 180 t.Fatalf("unable to create temp certfile: %s", err) 181 } 182 certFile.Close() 183 defer os.Remove(certFile.Name()) 184 185 keyFile, err := os.CreateTemp("", "keyfile") 186 if err != nil { 187 t.Fatalf("unable to create temp keyfile: %s", err) 188 } 189 keyFile.Close() 190 defer os.Remove(keyFile.Name()) 191 192 err = genCertPair(certFile.Name(), keyFile.Name(), nil) 193 if err != nil { 194 t.Fatal(err) 195 } 196 197 certB, err := os.ReadFile(certFile.Name()) 198 if err != nil { 199 t.Fatalf("file reading error: %v", err) 200 } 201 202 host := "127.0.0.1:0" 203 mux := http.NewServeMux() 204 mux.HandleFunc("/ws", handler) 205 206 // http server for the connect and upgrade 207 server := &http.Server{ 208 WriteTimeout: time.Second * 10, 209 ReadTimeout: time.Second * 10, 210 Addr: host, 211 Handler: mux, 212 } 213 defer server.Shutdown(context.Background()) 214 215 wg.Add(1) 216 serverReady := make(chan error, 1) 217 go func() { 218 defer wg.Done() 219 220 ln, err := net.Listen("tcp", server.Addr) 221 if err != nil { 222 serverReady <- err 223 return 224 } 225 defer ln.Close() 226 //log.Info(ln.Addr().(*net.TCPAddr).Port) 227 host = ln.Addr().String() 228 serverReady <- nil // after setting host 229 230 err = server.ServeTLS(ln, certFile.Name(), keyFile.Name()) 231 if err != nil { 232 fmt.Println(err) 233 } 234 }() 235 236 // wait for server to start listening before connecting 237 err = <-serverReady 238 if err != nil { 239 t.Fatal(err) 240 } 241 242 const pingWait = 500 * time.Millisecond 243 setupWsConn := func(cert []byte) (*wsConn, error) { 244 cfg := &WsCfg{ 245 URL: "wss://" + host + "/ws", 246 PingWait: pingWait, 247 Cert: cert, 248 Logger: tLogger, 249 } 250 conn, err := NewWsConn(cfg) 251 if err != nil { 252 return nil, err 253 } 254 return conn.(*wsConn), nil 255 } 256 257 // test no cert error 258 noCertConn, err := setupWsConn(nil) 259 if err != nil { 260 t.Fatal(err) 261 } 262 noCertConnMaster := dex.NewConnectionMaster(noCertConn) 263 err = noCertConnMaster.Connect(ctx) 264 noCertConnMaster.Disconnect() 265 if err == nil || !errors.Is(err, ErrCertRequired) { 266 t.Fatalf("failed to get ErrCertRequired for no cert connection, got %v", err) 267 } 268 269 // test invalid cert error 270 _, err = setupWsConn([]byte("invalid cert")) 271 if err == nil || !errors.Is(err, ErrInvalidCert) { 272 t.Fatalf("failed to get ErrInvalidCert for invalid cert connection, got %v", err) 273 } 274 275 // connect with cert 276 wsc, err := setupWsConn(certB) 277 if err != nil { 278 t.Fatal(err) 279 } 280 waiter := dex.NewConnectionMaster(wsc) 281 err = waiter.Connect(ctx) 282 if err != nil { 283 t.Fatalf("Connect: %v", err) 284 } 285 286 reconnectAndPing := func() { 287 // Drop the connection and force a reconnect by waiting longer than the 288 // read deadline (the ping wait), plus a bit extra to allow the timeout 289 // to flip off the connection and queue a reconnect. 290 time.Sleep(pingWait * 3 / 2) 291 runtime.Gosched() 292 293 // Wait for a reconnection. 294 for wsc.IsDown() { 295 time.Sleep(time.Millisecond * 10) 296 continue 297 } 298 299 // Send a ping. 300 pingCh <- struct{}{} 301 } 302 303 orderid, _ := hex.DecodeString("ceb09afa675cee31c0f858b94c81bd1a4c2af8c5947d13e544eef772381f2c8d") 304 matchid, _ := hex.DecodeString("7c6b44735e303585d644c713fe0e95897e7e8ba2b9bba98d6d61b70006d3d58c") 305 match := &msgjson.Match{ 306 OrderID: orderid, 307 MatchID: matchid, 308 Quantity: 20, 309 Rate: 2, 310 Address: "DsiNAJCd2sSazZRU9ViDD334DaLgU1Kse3P", 311 } 312 313 // Ensure a malformed message to the client does not terminate 314 // the connection. 315 readPumpCh <- []byte("{notjson") 316 317 // Send a message to the client. 318 sent := makeRequest(1, msgjson.MatchRoute, match) 319 readPumpCh <- sent 320 321 // Fetch the read source. 322 readSource := wsc.MessageSource() 323 if readSource == nil { 324 t.Fatal("expected a non-nil read source") 325 } 326 327 // Read the message received by the client. 328 received := <-readSource 329 330 // Ensure the received message equal to the sent message. 331 if received.Type != sent.Type { 332 t.Fatalf("expected %v type, got %v", sent.Type, received.Type) 333 } 334 335 if received.Route != sent.Route { 336 t.Fatalf("expected %v route, got %v", sent.Route, received.Route) 337 } 338 339 if received.ID != sent.ID { 340 t.Fatalf("expected %v id, got %v", sent.ID, received.ID) 341 } 342 343 if !bytes.Equal(received.Payload, sent.Payload) { 344 t.Fatal("sent and received payload mismatch") 345 } 346 347 reconnectAndPing() 348 349 coinID := []byte{ 350 0xc3, 0x16, 0x10, 0x33, 0xde, 0x09, 0x6f, 0xd7, 0x4d, 0x90, 0x51, 0xff, 351 0x0b, 0xd9, 0x9e, 0x35, 0x9d, 0xe3, 0x50, 0x80, 0xa3, 0x51, 0x10, 0x81, 352 0xed, 0x03, 0x5f, 0x54, 0x1b, 0x85, 0x0d, 0x43, 0x00, 0x00, 0x00, 0x0a, 353 } 354 355 contract, _ := hex.DecodeString("caf8d277f80f71e4") 356 init := &msgjson.Init{ 357 OrderID: orderid, 358 MatchID: matchid, 359 CoinID: coinID, 360 Contract: contract, 361 } 362 363 // Send a message from the client. 364 mId := wsc.NextID() 365 sent = makeRequest(mId, msgjson.InitRoute, init) 366 handlerRun := false 367 err = wsc.Request(sent, func(*msgjson.Message) { 368 handlerRun = true 369 }) 370 if err != nil { 371 t.Fatalf("unexpected error: %v", err) 372 } 373 374 // Read the message received by the server. 375 received = <-writePumpCh 376 377 // Ensure the received message equal to the sent message. 378 if received.Type != sent.Type { 379 t.Fatalf("expected %v type, got %v", sent.Type, received.Type) 380 } 381 382 if received.Route != sent.Route { 383 t.Fatalf("expected %v route, got %v", sent.Route, received.Route) 384 } 385 386 if received.ID != sent.ID { 387 t.Fatalf("expected %v id, got %v", sent.ID, received.ID) 388 } 389 390 if !bytes.Equal(received.Payload, sent.Payload) { 391 t.Fatal("sent and received payload mismatch") 392 } 393 394 // Ensure the next id is as expected. 395 next := wsc.NextID() 396 if next != 2 { 397 t.Fatalf("expected next id to be %d, got %d", 2, next) 398 } 399 400 // Ensure the request got logged, also unregister the response handler. 401 hndlr := wsc.respHandler(mId) 402 if hndlr == nil { 403 t.Fatalf("no handler found") 404 } 405 hndlr.f(nil) 406 if !handlerRun { 407 t.Fatalf("wrong handler retrieved") 408 } 409 410 // Ensure the response handler is unlogged. 411 hndlr = wsc.respHandler(mId) 412 if hndlr != nil { 413 t.Fatal("found a response handler for an unlogged request id") 414 } 415 416 pingCh <- struct{}{} 417 418 // Ensure malformed request data (a send failure) does not leave a 419 // registered response handler or kill the connection. 420 sent.ID = wsc.NextID() 421 sent.Payload = []byte("{notjson") 422 err = wsc.Request(sent, func(*msgjson.Message) {}) 423 if err == nil { 424 t.Fatalf("expected error with malformed request payload") 425 } 426 427 // Ensure the response handler is unregistered. 428 if wsc.respHandler(mId) != nil { 429 t.Fatal("response handler was still registered") 430 } 431 432 // New request to test expiration. 433 mId = next 434 sent = makeRequest(mId, msgjson.InitRoute, init) 435 expiring := make(chan struct{}, 1) 436 expTime := 50 * time.Millisecond // way shorter than pingWait 437 err = wsc.RequestWithTimeout(sent, func(*msgjson.Message) {}, expTime, func() { 438 expiring <- struct{}{} 439 }) 440 if err != nil { 441 t.Fatalf("unexpected error: %v", err) 442 } 443 <-writePumpCh 444 445 pingCh <- struct{}{} 446 447 // Yield to the comms goroutine in case this machine is poor. 448 runtime.Gosched() 449 select { 450 case <-expiring: 451 case <-time.NewTimer(time.Second).C: // >> expTime 452 t.Fatalf("didn't expire") // conn will be dead by this time without pings 453 } 454 455 // New request to abort on conn shutdown. 456 sent = makeRequest(wsc.NextID(), msgjson.InitRoute, init) 457 expiring = make(chan struct{}, 1) 458 expTime = 20 * time.Second // we're going to cancel first 459 beforeExpire := time.After(2 * time.Second) // enough time for shutdown to call expire func 460 err = wsc.RequestWithTimeout(sent, func(*msgjson.Message) {}, expTime, func() { 461 expiring <- struct{}{} 462 }) 463 if err != nil { 464 t.Fatalf("unexpected error: %v", err) 465 } 466 <-writePumpCh 467 468 pingCh <- struct{}{} 469 470 // Shutdown/Disconnect before expire. 471 time.Sleep(50 * time.Millisecond) // let pings and pongs flush, but it's not a problem if they bomb 472 waiter.Disconnect() 473 474 select { 475 case <-beforeExpire: // much shorter than req timeout 476 t.Error("expire func not called on conn shutdown") 477 case <-expiring: // means aborted if triggered before timeout 478 } 479 480 select { 481 case _, ok := <-readSource: 482 if ok { 483 t.Error("read source should have been closed") 484 } 485 default: 486 t.Error("read source should have been closed") 487 } 488 }