decred.org/dcrdex@v1.0.5/server/comms/server.go (about) 1 // This code is available on the terms of the project LICENSE.md file, 2 // also available online at https://blueoakcouncil.org/license/1.0.0. 3 4 package comms 5 6 import ( 7 "context" 8 "crypto/elliptic" 9 "crypto/tls" 10 "encoding/json" 11 "errors" 12 "fmt" 13 "net" 14 "net/http" 15 "os" 16 "strings" 17 "sync" 18 "sync/atomic" 19 "time" 20 21 "decred.org/dcrdex/dex" 22 "decred.org/dcrdex/dex/msgjson" 23 "decred.org/dcrdex/dex/ws" 24 "github.com/decred/dcrd/certgen" 25 "github.com/go-chi/chi/v5" 26 "github.com/go-chi/chi/v5/middleware" 27 "golang.org/x/time/rate" 28 ) 29 30 const ( 31 // rpcTimeoutSeconds is the number of seconds a connection to the 32 // RPC server is allowed to stay open without authenticating before it 33 // is closed. 34 rpcTimeoutSeconds = 10 35 36 // rpcMaxClients is the maximum number of active websocket connections 37 // allowed. 38 rpcMaxClients = 10000 39 40 // rpcMaxConnsPerIP is the maximum number of active websocket connections 41 // allowed per IP, loopback excluded. 42 rpcMaxConnsPerIP = 8 43 44 // banishTime is the default duration of a client quarantine. 45 banishTime = time.Hour 46 47 // Per-ip rate limits for market data API routes. 48 ipMaxRatePerSec = 1 49 ipMaxBurstSize = 5 50 51 // Per-websocket-connection limits in requests per second. Rate should be a 52 // reasonable sustained rate, while burst should consider bulk reconnect 53 // operations. Consider which routes are authenticated when setting these. 
54 wsRateStatus, wsBurstStatus = 10, 500 // order_status and match_status (combined) 55 wsRateOrder, wsBurstOrder = 5, 100 // market, limit, and cancel (combined) 56 wsRateInfo, wsBurstInfo = 10, 200 // low-cost route limiter for: config, fee_rate, spots, candles (combined) 57 wsRateSubs, wsBurstSubs = 1 / 2.0, 100 // subscriptions: orderbook and price feed (combined) 58 wsRateConnect, wsBurstConnect = 1 / 5.0, 100 // connect, account discovery requires bursts - (*Core).discoverAccount 59 // The cumulative rates below would need to be less than sum of above to 60 // actually trip unless it is also applied to unspecified routes. 61 wsRateTotal, wsBurstTotal = 40, 1000 62 ) 63 64 var ( 65 // Time allowed to read the next pong message from the peer. The default is 66 // intended for production, but leaving as a var instead of const to 67 // facilitate testing. This is the websocket read timeout set by the pong 68 // handler. The first read deadline is set by the ws.WSLink. 69 pongWait = 20 * time.Second 70 71 // Send pings to peer with this period. Must be less than pongWait. The 72 // default is intended for production, but leaving as a var instead of const 73 // to facilitate testing. 74 pingPeriod = (pongWait * 9) / 10 // i.e. 18 sec 75 76 // globalHTTPRateLimiter is a limit on the global HTTP request limit. The 77 // global rate limiter is like a rudimentary auto-spam filter for 78 // non-critical routes, including all routes registered as HTTP routes. 79 globalHTTPRateLimiter = rate.NewLimiter(100, 1000) // rate per sec, max burst 80 81 // ipHTTPRateLimiter is a per-client rate limiter for the HTTP endpoints 82 // requests and httpRoutes (the market data API). The Server manages 83 // separate limiters used with the websocket routes, rpcRoutes. 84 ipHTTPRateLimiter = make(map[dex.IPKey]*ipRateLimiter) 85 rateLimiterMtx sync.RWMutex 86 ) 87 88 var idCounter uint64 89 90 // ipRateLimiter is used to track an IPs HTTP request rate. 
91 type ipRateLimiter struct { 92 *rate.Limiter 93 lastHit time.Time 94 } 95 96 // Get an ipRateLimiter for the IP. Creates a new one if it doesn't exist. This 97 // is for use with the HTTP endpoints and httpRoutes (the data API), not the 98 // websocket request routes in rpcRoutes. 99 func getIPLimiter(ip dex.IPKey) *ipRateLimiter { 100 rateLimiterMtx.Lock() 101 defer rateLimiterMtx.Unlock() 102 limiter := ipHTTPRateLimiter[ip] 103 if limiter != nil { 104 limiter.lastHit = time.Now() 105 return limiter 106 } 107 limiter = &ipRateLimiter{ 108 Limiter: rate.NewLimiter(ipMaxRatePerSec, ipMaxBurstSize), 109 lastHit: time.Now(), 110 } 111 ipHTTPRateLimiter[ip] = limiter 112 return limiter 113 } 114 115 // NextID returns a unique ID to identify a request-type message. 116 func NextID() uint64 { 117 return atomic.AddUint64(&idCounter, 1) 118 } 119 120 // MsgHandler describes a handler for a specific message route. 121 type MsgHandler func(Link, *msgjson.Message) *msgjson.Error 122 123 // HTTPHandler describes a handler for an HTTP route. 124 type HTTPHandler func(thing any) (any, error) 125 126 // Route registers a handler for a specified route. The handler map is global 127 // and has no mutex protection. All calls to Route should be done before the 128 // Server is started. 
129 func (s *Server) Route(route string, handler MsgHandler) { 130 if route == "" { 131 panic("Route: route is empty string") 132 } 133 _, alreadyHave := s.rpcRoutes[route] 134 if alreadyHave { 135 panic(fmt.Sprintf("Route: double registration: %s", route)) 136 } 137 s.rpcRoutes[route] = handler 138 } 139 140 func (s *Server) RegisterHTTP(route string, handler HTTPHandler) { 141 if route == "" { 142 panic("RegisterHTTP: route is empty string") 143 } 144 _, alreadyHave := s.httpRoutes[route] 145 if alreadyHave { 146 panic(fmt.Sprintf("RegisterHTTP: double registration: %s", route)) 147 } 148 s.httpRoutes[route] = handler 149 } 150 151 // The RPCConfig is the server configuration settings and the only argument 152 // to the server's constructor. 153 type RPCConfig struct { 154 // HiddenServiceAddr is the local address to which connections from the 155 // local hidden service will connect, e.g. 127.0.0.1:7252. This is not the 156 // .onion address of the hidden service. The TLS key pairs do not apply to 157 // these connections since TLS is not used on the hidden service's listener. 158 // This corresponds to the last component of a HiddenServicePort line in a 159 // torrc config file. e.g. HiddenServicePort 7232 127.0.0.1:7252. Clients 160 // would specify the port preceding this address in the above statement. 161 HiddenServiceAddr string 162 // ListenAddrs are the addresses on which the server will listen. 163 ListenAddrs []string 164 // The location of the TLS keypair files. If they are not already at the 165 // specified location, a keypair with a self-signed certificate will be 166 // generated and saved to these locations. 167 RPCKey string 168 RPCCert string 169 NoTLS bool 170 // AltDNSNames specifies allowable request addresses for an auto-generated 171 // TLS keypair. Changing AltDNSNames does not force the keypair to be 172 // regenerated. To regenerate, delete or move the old files. 
173 AltDNSNames []string 174 // DisableDataAPI will disable all traffic to the HTTP data API routes. 175 DisableDataAPI bool 176 } 177 178 // allower is satisfied by rate.Limiter. 179 type allower interface { 180 Allow() bool 181 } 182 183 // routeLimiter contains a set of rate limiters for individual routes, and a 184 // cumulative limiter applied after defined routers are applied. No limiter is 185 // applied to an unspecified route. 186 type routeLimiter struct { 187 routes map[string]allower 188 cumulative allower // only used for defined routes 189 } 190 191 func (rl *routeLimiter) allow(route string) bool { 192 // To apply the cumulative limiter to all routes including those without 193 // their own limiter, we would apply it here. Maybe go with this if we are 194 // confident it's not going to interfere with init/redeem or others. 195 // if !rl.cumulative.Allow() { 196 // return false 197 // } 198 limiter := rl.routes[route] 199 if limiter == nil { 200 return true // free 201 } 202 return rl.cumulative.Allow() && limiter.Allow() 203 } 204 205 // newRouteLimiter creates a route-based rate limiter. It should be applied to 206 // all connections from a given IP address. 
207 func newRouteLimiter() *routeLimiter { 208 // Some routes share a limiter to aggregate request stats: 209 statusLimiter := rate.NewLimiter(wsRateStatus, wsBurstStatus) 210 orderLimiter := rate.NewLimiter(wsRateOrder, wsBurstOrder) 211 infoLimiter := rate.NewLimiter(wsRateInfo, wsBurstInfo) 212 marketSubsLimiter := rate.NewLimiter(wsRateSubs, wsBurstSubs) 213 return &routeLimiter{ 214 cumulative: rate.NewLimiter(wsRateTotal, wsBurstTotal), 215 routes: map[string]allower{ 216 // Connect (authorize) route 217 msgjson.ConnectRoute: rate.NewLimiter(wsRateConnect, wsBurstConnect), 218 // Status checking of matches and orders 219 msgjson.MatchStatusRoute: statusLimiter, 220 msgjson.OrderStatusRoute: statusLimiter, 221 // Order submission 222 msgjson.LimitRoute: orderLimiter, 223 msgjson.MarketRoute: orderLimiter, 224 msgjson.CancelRoute: orderLimiter, 225 // Order book and price feed subscriptions 226 msgjson.OrderBookRoute: marketSubsLimiter, 227 msgjson.PriceFeedRoute: marketSubsLimiter, 228 // Config, fee rate, spot prices, and candles 229 msgjson.FeeRateRoute: infoLimiter, 230 msgjson.ConfigRoute: infoLimiter, 231 msgjson.SpotsRoute: infoLimiter, 232 msgjson.CandlesRoute: infoLimiter, 233 }, 234 } 235 } 236 237 // ipWsLimiter facilitates connection counting for a source IP address to 238 // aggregate requests stats by a single rate limiter. 239 type ipWsLimiter struct { 240 conns int64 241 cleaner *time.Timer // when conns drops to zero, set a cleanup timer 242 *routeLimiter 243 } 244 245 // Server is a low-level communications hub. It supports websocket clients 246 // and an HTTP API. 247 type Server struct { 248 mux *chi.Mux 249 // One listener for each address specified at (RPCConfig).ListenAddrs. 250 listeners []net.Listener 251 252 // The client map indexes each wsLink by its id. 
253 clientMtx sync.RWMutex 254 clients map[uint64]*wsLink 255 counter uint64 // for generating unique client IDs 256 257 // wsLimiters manages per-IP per-route websocket connection request rate 258 // limiters that are not subject to server-wide rate limits or affected by 259 // disabling of the data API (Server.dataEnabled). 260 wsLimiterMtx sync.Mutex // the map and the fields of each limiter 261 wsLimiters map[dex.IPKey]*ipWsLimiter 262 v6Prefixes map[dex.IPKey]int // just debugging presently 263 264 // The quarantine map maps IP addresses to a time in which the quarantine will 265 // be lifted. 266 banMtx sync.RWMutex 267 quarantine map[dex.IPKey]time.Time 268 269 dataEnabled uint32 // atomic 270 271 // rpcRoutes maps message routes to the handlers. 272 rpcRoutes map[string]MsgHandler 273 // httpRoutes maps HTTP routes to the handlers. 274 httpRoutes map[string]HTTPHandler 275 } 276 277 // NewServer constructs a Server that should be started with Run. The server is 278 // TLS-only, and will generate a key pair with a self-signed certificate if one 279 // is not provided as part of the RPCConfig. The server also maintains a 280 // IP-based quarantine to short-circuit to an error response for misbehaving 281 // clients, if necessary. 282 func NewServer(cfg *RPCConfig) (*Server, error) { 283 284 var tlsConfig *tls.Config 285 if !cfg.NoTLS { 286 // Prepare the TLS configuration. 
287 keyExists := dex.FileExists(cfg.RPCKey) 288 certExists := dex.FileExists(cfg.RPCCert) 289 if certExists == !keyExists { 290 return nil, fmt.Errorf("missing cert pair file") 291 } 292 if !keyExists && !certExists { 293 err := genCertPair(cfg.RPCCert, cfg.RPCKey, cfg.AltDNSNames) 294 if err != nil { 295 return nil, err 296 } 297 } 298 keypair, err := tls.LoadX509KeyPair(cfg.RPCCert, cfg.RPCKey) 299 if err != nil { 300 return nil, err 301 } 302 tlsConfig = &tls.Config{ 303 Certificates: []tls.Certificate{keypair}, // TODO: multiple key pairs for virtual hosting 304 MinVersion: tls.VersionTLS12, 305 } 306 } 307 308 // Start with the hidden service listener, if specified. 309 var listeners []net.Listener 310 if cfg.HiddenServiceAddr == "" { 311 listeners = make([]net.Listener, 0, len(cfg.ListenAddrs)) 312 } else { 313 listeners = make([]net.Listener, 0, 1+len(cfg.ListenAddrs)) 314 ipv4ListenAddrs, ipv6ListenAddrs, _, err := parseListeners([]string{cfg.HiddenServiceAddr}) 315 if err != nil { 316 return nil, err 317 } 318 for _, addr := range ipv4ListenAddrs { 319 listener, err := net.Listen("tcp4", addr) 320 if err != nil { 321 return nil, fmt.Errorf("cannot listen on %s: %w", addr, err) 322 } 323 listeners = append(listeners, onionListener{listener}) 324 } 325 for _, addr := range ipv6ListenAddrs { 326 listener, err := net.Listen("tcp6", addr) 327 if err != nil { 328 return nil, fmt.Errorf("cannot listen on %s: %w", addr, err) 329 } 330 listeners = append(listeners, onionListener{listener}) 331 } 332 } 333 334 // Parse the specified listen addresses and create the []net.Listener. 
335 ipv4ListenAddrs, ipv6ListenAddrs, _, err := parseListeners(cfg.ListenAddrs) 336 if err != nil { 337 return nil, err 338 } 339 parseListener := func(network, addr string) (err error) { 340 var listener net.Listener 341 if cfg.NoTLS { 342 listener, err = net.Listen(network, addr) 343 } else { 344 listener, err = tls.Listen(network, addr, tlsConfig) 345 } 346 if err != nil { 347 return fmt.Errorf("cannot listen on %s: %w", addr, err) 348 } 349 listeners = append(listeners, listener) 350 return nil 351 } 352 353 for _, addr := range ipv4ListenAddrs { 354 if err := parseListener("tcp4", addr); err != nil { 355 return nil, err 356 } 357 } 358 for _, addr := range ipv6ListenAddrs { 359 if err := parseListener("tcp6", addr); err != nil { 360 return nil, err 361 } 362 } 363 if len(listeners) == 0 { 364 return nil, fmt.Errorf("RPCS: No valid listen address") 365 } 366 var dataEnabled uint32 = 1 367 if cfg.DisableDataAPI { 368 dataEnabled = 0 369 } 370 371 // Create an HTTP router, putting a couple of useful middlewares in place. 372 mux := chi.NewRouter() 373 mux.Use(middleware.RealIP) 374 mux.Use(middleware.Recoverer) 375 376 return &Server{ 377 mux: mux, 378 listeners: listeners, 379 clients: make(map[uint64]*wsLink), 380 wsLimiters: make(map[dex.IPKey]*ipWsLimiter), 381 v6Prefixes: make(map[dex.IPKey]int), 382 quarantine: make(map[dex.IPKey]time.Time), 383 dataEnabled: dataEnabled, 384 rpcRoutes: make(map[string]MsgHandler), 385 httpRoutes: make(map[string]HTTPHandler), 386 }, nil 387 } 388 389 type onionListener struct{ net.Listener } 390 391 // Run starts the server. Run should be called only after all routes are 392 // registered. 393 func (s *Server) Run(ctx context.Context) { 394 mux := s.mux 395 var wg sync.WaitGroup 396 397 // Websocket endpoint. 
398 mux.Get("/ws", func(w http.ResponseWriter, r *http.Request) { 399 ip := dex.NewIPKey(r.RemoteAddr) 400 if s.isQuarantined(ip) { 401 http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) 402 return 403 } 404 if s.clientCount() >= rpcMaxClients { 405 http.Error(w, "server at maximum capacity", http.StatusServiceUnavailable) 406 return 407 } 408 409 // Check websocket connection count for this IP before upgrading the 410 // conn so we can send an HTTP error code, but check again after 411 // upgrade/hijack so they cannot initiate many simultaneously. 412 if s.ipConnCount(ip) >= rpcMaxConnsPerIP { 413 http.Error(w, "too many connections from your address", http.StatusServiceUnavailable) 414 return 415 } 416 417 wsConn, err := ws.NewConnection(w, r, pongWait) 418 if err != nil { 419 log.Errorf("ws connection error: %v", err) 420 return 421 } 422 423 _, isHiddenService := r.Context().Value(ctxListener).(onionListener) 424 if isHiddenService { 425 log.Infof("Hidden service websocket connection starting from %v", r.RemoteAddr) // should be 127.0.0.1 426 } 427 // TODO: give isHiddenService to websocketHandler, possibly with a 428 // special dex.IPKey rather than the one from r.RemoteAddr 429 430 // http.Server.Shutdown waits for connections to complete (such as this 431 // http.HandlerFunc), but not the long running upgraded websocket 432 // connections. We must wait on each websocketHandler to return in 433 // response to disconnectClients. 
434 log.Debugf("Starting websocket handler for %s", r.RemoteAddr) // includes source port 435 wg.Add(1) 436 go func() { 437 defer wg.Done() 438 s.websocketHandler(ctx, wsConn, ip) 439 }() 440 }) 441 442 httpServer := &http.Server{ 443 Handler: mux, 444 ReadTimeout: rpcTimeoutSeconds * time.Second, // slow requests should not hold connections opened 445 WriteTimeout: rpcTimeoutSeconds * time.Second, // hung responses must die 446 BaseContext: func(l net.Listener) context.Context { 447 return context.WithValue(ctx, ctxListener, l) // the actual listener is not really useful, maybe drop it 448 }, 449 } 450 451 // Start serving. 452 for _, listener := range s.listeners { 453 wg.Add(1) 454 go func(listener net.Listener) { 455 log.Infof("Server listening on %s", listener.Addr()) 456 err := httpServer.Serve(listener) 457 if !errors.Is(err, http.ErrServerClosed) { 458 log.Warnf("unexpected (http.Server).Serve error: %v", err) 459 } 460 log.Debugf("RPC listener done for %s", listener.Addr()) 461 wg.Done() 462 }(listener) 463 } 464 465 // Run a periodic routine to keep the ipHTTPRateLimiter map clean. 466 go func() { 467 ticker := time.NewTicker(time.Minute * 5) 468 defer ticker.Stop() 469 for { 470 select { 471 case <-ticker.C: 472 rateLimiterMtx.Lock() 473 for ip, limiter := range ipHTTPRateLimiter { 474 if time.Since(limiter.lastHit) > time.Minute { 475 delete(ipHTTPRateLimiter, ip) 476 } 477 } 478 rateLimiterMtx.Unlock() 479 case <-ctx.Done(): 480 return 481 } 482 } 483 }() 484 485 <-ctx.Done() 486 487 // Shutdown the server. This stops all listeners and waits for connections. 488 log.Infof("Server shutting down...") 489 ctxTimeout, cancel := context.WithTimeout(context.Background(), 5*time.Second) 490 defer cancel() 491 err := httpServer.Shutdown(ctxTimeout) 492 if err != nil { 493 log.Warnf("http.Server.Shutdown: %v", err) 494 } 495 496 // Stop and disconnect websocket clients. 
497 s.disconnectClients() 498 499 // When the http.Server is shut down, all websocket clients are gone, and 500 // the listener goroutines have returned, the server is shut down. 501 wg.Wait() 502 log.Infof("Server shutdown complete") 503 } 504 505 func (s *Server) Mux() *chi.Mux { 506 return s.mux 507 } 508 509 // Check if the IP address is quarantined. 510 func (s *Server) isQuarantined(ip dex.IPKey) bool { 511 s.banMtx.RLock() 512 banTime, banned := s.quarantine[ip] 513 s.banMtx.RUnlock() 514 if banned { 515 // See if the ban has expired. 516 if time.Now().After(banTime) { 517 s.banMtx.Lock() 518 delete(s.quarantine, ip) 519 s.banMtx.Unlock() 520 banned = false 521 } 522 } 523 return banned 524 } 525 526 // Quarantine the specified IP address. 527 func (s *Server) banish(ip dex.IPKey) { 528 s.banMtx.Lock() 529 defer s.banMtx.Unlock() 530 s.quarantine[ip] = time.Now().Add(banishTime) 531 } 532 533 // wsLimiter gets any existing routeLimiter for an IP incrementing the 534 // connection count for the address, or creates a new one. The caller should use 535 // wsLimiterDone after the connection that uses routeLimiter is closed. Loopback 536 // addresses always get a new unshared limiter. NewIPKey should be used to 537 // create an IPKey with interface bits masked out. This is not perfect with 538 // respect to remote IPv6 hosts assigned multiple subnets (up to 16 bits worth). 539 // Disable IPv6 if this is not acceptable. 540 func (s *Server) wsLimiter(ip dex.IPKey) *routeLimiter { 541 // If the ip is a loopback address, this likely indicates a hidden service 542 // or misconfigured reverse proxy, and it is undesirable for many such 543 // connections to share a common limiter. To avoid this, return a new 544 // untracked limiter for such clients. 
545 if ip.IsLoopback() { 546 return newRouteLimiter() 547 } 548 549 s.wsLimiterMtx.Lock() 550 defer s.wsLimiterMtx.Unlock() 551 prefix := ip.PrefixV6() 552 if prefix != nil { // not ipv4 553 if n := s.v6Prefixes[*prefix]; n > 0 { 554 log.Infof("Detected %d active IPv6 connections with same prefix %v", n, prefix) 555 // Consider: Use a prefix-aggregated limiter when n > threshold. If 556 // we want to get really sophisticated, we may look into a tiered 557 // aggregation algorithm. https://serverfault.com/a/919324/190378 558 // 559 // ip = *prefix 560 } 561 } 562 563 if l := s.wsLimiters[ip]; l != nil { 564 if l.conns >= rpcMaxConnsPerIP { 565 return nil 566 } 567 l.conns++ 568 if prefix != nil { 569 s.v6Prefixes[*prefix]++ 570 } 571 if l.cleaner != nil { // l.conns was zero 572 log.Debugf("Restoring active rate limiter for %v", ip) 573 // Even if the timer already fired, we won the race to the lock and 574 // incremented conns so the cleaner func will be a no-op. 575 l.cleaner.Stop() // false means timer fired already 576 l.cleaner = nil 577 } 578 return l.routeLimiter 579 } 580 581 limiter := newRouteLimiter() 582 s.wsLimiters[ip] = &ipWsLimiter{ 583 conns: 1, 584 routeLimiter: limiter, 585 } 586 if prefix != nil { 587 s.v6Prefixes[*prefix]++ 588 } 589 return limiter 590 } 591 592 // wsLimiterDone decrements the connection count for the IP address' 593 // routeLimiter, and deletes it entirely if there are no remaining connections 594 // from this address. 595 func (s *Server) wsLimiterDone(ip dex.IPKey) { 596 s.wsLimiterMtx.Lock() 597 defer s.wsLimiterMtx.Unlock() 598 599 if prefix := ip.PrefixV6(); prefix != nil { 600 switch s.v6Prefixes[*prefix] { 601 case 0: 602 case 1: 603 delete(s.v6Prefixes, *prefix) 604 default: 605 s.v6Prefixes[*prefix]-- 606 } 607 } 608 609 wsLimiter := s.wsLimiters[ip] 610 if wsLimiter == nil { 611 return // untracked limiter (i.e. loopback) 612 // If using prefix-aggregated limiters, we'd check for one here. 
613 } 614 615 wsLimiter.conns-- 616 if wsLimiter.conns < 1 { 617 // Start a cleanup timer. 618 wsLimiter.cleaner = time.AfterFunc(time.Minute, func() { 619 s.wsLimiterMtx.Lock() 620 defer s.wsLimiterMtx.Unlock() 621 if wsLimiter.conns < 1 { 622 log.Debugf("Forgetting rate limiter for %v", ip) 623 delete(s.wsLimiters, ip) 624 } // else lost the race to the mutex, don't remove 625 }) 626 } 627 } 628 629 // websocketHandler handles a new websocket client by creating a new wsClient, 630 // starting it, and blocking until the connection closes. This method should be 631 // run as a goroutine. 632 func (s *Server) websocketHandler(ctx context.Context, conn ws.Connection, ip dex.IPKey) { 633 addr := ip.String() 634 log.Tracef("New websocket client %s", addr) 635 636 // Create a new websocket client to handle the new websocket connection 637 // and wait for it to shutdown. Once it has shutdown (and hence 638 // disconnected), remove it. 639 dataRoutesMeter := func() (int, error) { return s.meterIP(ip) } // includes global limiter and may be disabled 640 wsLimiter := s.wsLimiter(ip) 641 if wsLimiter == nil { // too many active ws conns from this IP 642 log.Warnf("Too many websocket connections from %v", ip) 643 return 644 } 645 defer s.wsLimiterDone(ip) 646 client := s.newWSLink(addr, conn, wsLimiter, dataRoutesMeter) 647 648 cm, err := s.addClient(ctx, client) 649 if err != nil { 650 log.Errorf("Failed to add client %s", addr) 651 return 652 } 653 defer s.removeClient(client.id) 654 655 // The connection remains until the connection is lost or the link's 656 // disconnect method is called (e.g. via disconnectClients). 657 cm.Wait() 658 659 // If the ban flag is set, quarantine the client's IP address. 660 if client.ban { 661 s.banish(ip) 662 } 663 log.Tracef("Disconnected websocket client %s", addr) 664 } 665 666 // Broadcast sends a message to all connected clients. The message should be a 667 // notification. See msgjson.NewNotification. 
668 func (s *Server) Broadcast(msg *msgjson.Message) { 669 // Marshal and send the bytes to avoid multiple marshals when sending. 670 b, err := json.Marshal(msg) 671 if err != nil { 672 log.Errorf("unable to marshal broadcast Message: %v", err) 673 return 674 } 675 676 s.clientMtx.RLock() 677 defer s.clientMtx.RUnlock() 678 679 log.Infof("Broadcasting %s for route %s to %d clients...", msg.Type, msg.Route, len(s.clients)) 680 if log.Level() <= dex.LevelTrace { // don't marshal unless needed 681 log.Tracef("Broadcast: %q", msg.String()) 682 } 683 684 for id, cl := range s.clients { 685 if err := cl.SendRaw(b); err != nil { 686 log.Debugf("Send to client %d at %s failed: %v", id, cl.Addr(), err) 687 cl.Disconnect() // triggers return of websocketHandler, and removeClient 688 } 689 } 690 } 691 692 // EnableDataAPI enables or disables the HTTP data API endpoints. 693 func (s *Server) EnableDataAPI(yes bool) { 694 if yes { 695 atomic.StoreUint32(&s.dataEnabled, 1) 696 } else { 697 atomic.StoreUint32(&s.dataEnabled, 0) 698 } 699 } 700 701 // disconnectClients calls disconnect on each wsLink, but does not remove it 702 // from the Server's client map. 703 func (s *Server) disconnectClients() { 704 s.clientMtx.Lock() 705 for _, link := range s.clients { 706 link.Disconnect() 707 } 708 s.clientMtx.Unlock() 709 } 710 711 // addClient assigns the client an ID, adds it to the map, and attempts to 712 // connect. 713 func (s *Server) addClient(ctx context.Context, client *wsLink) (*dex.ConnectionMaster, error) { 714 s.clientMtx.Lock() 715 defer s.clientMtx.Unlock() 716 cm := dex.NewConnectionMaster(client) 717 if err := cm.ConnectOnce(ctx); err != nil { 718 return nil, err 719 } 720 client.id = s.counter 721 s.counter++ 722 s.clients[client.id] = client 723 return cm, nil 724 } 725 726 // Remove the client from the map. 
727 func (s *Server) removeClient(id uint64) { 728 s.clientMtx.Lock() 729 delete(s.clients, id) 730 s.clientMtx.Unlock() 731 } 732 733 // Get the number of active clients. 734 func (s *Server) clientCount() uint64 { 735 s.clientMtx.RLock() 736 defer s.clientMtx.RUnlock() 737 return uint64(len(s.clients)) 738 } 739 740 // Get the number of websocket connections for a given IP, excluding loopback. 741 func (s *Server) ipConnCount(ip dex.IPKey) int64 { 742 s.wsLimiterMtx.Lock() 743 defer s.wsLimiterMtx.Unlock() 744 wsl := s.wsLimiters[ip] 745 if wsl == nil { 746 return 0 747 } 748 return wsl.conns 749 } 750 751 // genCertPair generates a key/cert pair to the paths provided. 752 func genCertPair(certFile, keyFile string, altDNSNames []string) error { 753 log.Infof("Generating TLS certificates...") 754 755 org := "dcrdex autogenerated cert" 756 validUntil := time.Now().Add(10 * 365 * 24 * time.Hour) 757 cert, key, err := certgen.NewTLSCertPair(elliptic.P521(), org, 758 validUntil, altDNSNames) 759 if err != nil { 760 return err 761 } 762 763 // Write cert and key files. 764 if err = os.WriteFile(certFile, cert, 0644); err != nil { 765 return err 766 } 767 if err = os.WriteFile(keyFile, key, 0600); err != nil { 768 os.Remove(certFile) 769 return err 770 } 771 772 log.Infof("Done generating TLS certificates") 773 return nil 774 } 775 776 // parseListeners splits the list of listen addresses passed in addrs into 777 // IPv4 and IPv6 slices and returns them. This allows easy creation of the 778 // listeners on the correct interface "tcp4" and "tcp6". It also properly 779 // detects addresses which apply to "all interfaces" and adds the address to 780 // both slices. 
func parseListeners(addrs []string) ([]string, []string, bool, error) {
	v4 := make([]string, 0, len(addrs))
	v6 := make([]string, 0, len(addrs))
	var wildcard bool

	for _, addr := range addrs {
		host, _, err := net.SplitHostPort(addr)
		if err != nil {
			// Shouldn't happen due to already being normalized.
			return nil, nil, false, err
		}

		if host == "" {
			// All interfaces: listen on both address families.
			v4 = append(v4, addr)
			v6 = append(v6, addr)
			wildcard = true
			continue
		}

		// net.ParseIP does not understand IPv6 zone ids, so drop any
		// trailing "%zone" before parsing.
		if idx := strings.LastIndex(host, "%"); idx > 0 {
			host = host[:idx]
		}

		ip := net.ParseIP(host)
		switch {
		case ip == nil:
			return nil, nil, false, fmt.Errorf("'%s' is not a valid IP address", host)
		case ip.To4() != nil: // To4 is non-nil only for IPv4 addresses
			v4 = append(v4, addr)
		default:
			v6 = append(v6, addr)
		}
	}
	return v4, v6, wildcard, nil
}

// NewRouteHandler creates a HandlerFunc for a registered HTTP route.
// Middleware should have already processed the request and added the request
// struct to the Context.
827 func (s *Server) NewRouteHandler(route string) func(w http.ResponseWriter, r *http.Request) { 828 handler := s.httpRoutes[route] 829 if handler == nil { 830 panic("no known handler for " + route) 831 } 832 return func(w http.ResponseWriter, r *http.Request) { 833 resp, err := handler(r.Context().Value(CtxThing)) 834 if err != nil { 835 writeJSONWithStatus(w, map[string]string{"error": err.Error()}, http.StatusBadRequest) 836 return 837 } 838 writeJSONWithStatus(w, resp, http.StatusOK) 839 } 840 } 841 842 // writeJSONWithStatus writes the JSON response with the specified HTTP response 843 // code. 844 func writeJSONWithStatus(w http.ResponseWriter, thing any, code int) { 845 w.Header().Set("Content-Type", "application/json; charset=utf-8") 846 b, err := json.Marshal(thing) 847 if err != nil { 848 w.WriteHeader(http.StatusInternalServerError) 849 log.Errorf("JSON encode error: %v", err) 850 return 851 } 852 w.WriteHeader(code) 853 _, err = w.Write(append(b, byte('\n'))) 854 if err != nil { 855 log.Errorf("Write error: %v", err) 856 } 857 }