// github.com/minio/minio@v0.0.0-20240328213742-3f72439b8a27/internal/grid/connection.go (about)

// Copyright (c) 2015-2023 MinIO, Inc.
//
// This file is part of MinIO Object Storage stack
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

package grid

import (
	"bytes"
	"context"
	"crypto/tls"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"math"
	"math/rand"
	"net"
	"net/http"
	"runtime/debug"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/gobwas/ws"
	"github.com/gobwas/ws/wsutil"
	"github.com/google/uuid"
	"github.com/minio/madmin-go/v3"
	xioutil "github.com/minio/minio/internal/ioutil"
	"github.com/minio/minio/internal/logger"
	"github.com/minio/minio/internal/pubsub"
	"github.com/puzpuzpuz/xsync/v3"
	"github.com/tinylib/msgp/msgp"
	"github.com/zeebo/xxh3"
)

// A Connection is a remote connection.
// There is no distinction externally whether the connection was initiated from
// this server or from the remote.
type Connection struct {
	// NextID is the next ID that can be used (atomic).
	// It is advanced with atomic.AddUint64 every time a mux client is created.
	NextID uint64

	// LastPong is last pong time (atomic)
	// Only valid when StateConnected.
	LastPong int64

	// State of the connection (atomic)
	// Writers hold connChange.L while updating it.
	state State

	// Non-atomic
	// Remote and Local are the host strings of the two endpoints.
	Remote string
	Local  string

	// ID of this connection instance.
	id uuid.UUID

	// Remote uuid, if we have been connected.
	remoteID *uuid.UUID
	// reconnectMu is taken while a new underlying connection is being set up
	// and is released by handleMessages once its goroutines are registered.
	reconnectMu sync.Mutex

	// Context for the server.
	ctx context.Context

	// Active mux connections.
	outgoing *xsync.MapOf[uint64, *muxClient]

	// Incoming streams
	inStream *xsync.MapOf[uint64, *muxServer]

	// outQueue is the output queue
	// Encoded messages are pushed here and drained by the write loop.
	outQueue chan []byte

	// Client or serverside.
	side ws.State

	// Transport for outgoing connections.
	dialer ContextDialer
	// header carries the HTTP headers sent with the websocket handshake.
	header http.Header

	// handleMsgWg tracks the read and write loops of the current connection.
	handleMsgWg sync.WaitGroup

	// connChange will be signaled whenever State has been updated, or at regular intervals.
	// Holding the lock allows safe reads of State, and guarantees that changes will be detected.
	connChange *sync.Cond
	handlers   *handlers

	remote             *RemoteClient
	auth               AuthFn
	clientPingInterval time.Duration
	connPingInterval   time.Duration
	tlsConfig          *tls.Config
	// blockConnect, when non-nil, delays incoming connections until it is closed or signaled.
	blockConnect chan struct{}

	incomingBytes func(n int64) // Record incoming bytes.
	outgoingBytes func(n int64) // Record outgoing bytes.
	trace         *tracer       // tracer for this connection.
	// baseFlags are added to every message queued through queueMsg.
	baseFlags Flags

	// For testing only
	debugInConn  net.Conn
	debugOutConn net.Conn
	addDeadline  time.Duration
	// connMu guards the debug connections and reads of connPingInterval.
	connMu sync.Mutex
}

// Subroute is a connection subroute that can be used to route to a specific handler with the same handler ID.
type Subroute struct {
	*Connection
	trace *tracer
	route string
	subID subHandlerID
}

// String returns a string representation of the connection.
// The format is "<local>-><remote>".
func (c *Connection) String() string {
	return fmt.Sprintf("%s->%s", c.Local, c.Remote)
}

// StringReverse returns a string representation of the reverse connection.
134 func (c *Connection) StringReverse() string { 135 return fmt.Sprintf("%s->%s", c.Remote, c.Local) 136 } 137 138 // State is a connection state. 139 type State uint32 140 141 // MANUAL go:generate stringer -type=State -output=state_string.go -trimprefix=State $GOFILE 142 143 const ( 144 // StateUnconnected is the initial state of a connection. 145 // When the first message is sent it will attempt to connect. 146 StateUnconnected = iota 147 148 // StateConnecting is the state from StateUnconnected while the connection is attempted to be established. 149 // After this connection will be StateConnected or StateConnectionError. 150 StateConnecting 151 152 // StateConnected is the state when the connection has been established and is considered stable. 153 // If the connection is lost, state will switch to StateConnecting. 154 StateConnected 155 156 // StateConnectionError is the state once a connection attempt has been made, and it failed. 157 // The connection will remain in this stat until the connection has been successfully re-established. 158 StateConnectionError 159 160 // StateShutdown is the state when the server has been shut down. 161 // This will not be used under normal operation. 162 StateShutdown 163 164 // MaxDeadline is the maximum deadline allowed, 165 // Approx 49 days. 166 MaxDeadline = time.Duration(math.MaxUint32) * time.Millisecond 167 ) 168 169 // ContextDialer is a dialer that can be used to dial a remote. 170 type ContextDialer func(ctx context.Context, network, address string) (net.Conn, error) 171 172 // DialContext implements the Dialer interface. 
173 func (c ContextDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { 174 return c(ctx, network, address) 175 } 176 177 const ( 178 defaultOutQueue = 65535 // kind of close to max open fds per user 179 readBufferSize = 32 << 10 // 32 KiB is the most optimal on Linux 180 writeBufferSize = 32 << 10 // 32 KiB is the most optimal on Linux 181 defaultDialTimeout = 2 * time.Second 182 connPingInterval = 10 * time.Second 183 connWriteTimeout = 3 * time.Second 184 ) 185 186 type connectionParams struct { 187 ctx context.Context 188 id uuid.UUID 189 local, remote string 190 dial ContextDialer 191 handlers *handlers 192 auth AuthFn 193 tlsConfig *tls.Config 194 incomingBytes func(n int64) // Record incoming bytes. 195 outgoingBytes func(n int64) // Record outgoing bytes. 196 publisher *pubsub.PubSub[madmin.TraceInfo, madmin.TraceType] 197 198 blockConnect chan struct{} 199 } 200 201 // newConnection will create an unconnected connection to a remote. 202 func newConnection(o connectionParams) *Connection { 203 c := &Connection{ 204 state: StateUnconnected, 205 Remote: o.remote, 206 Local: o.local, 207 id: o.id, 208 ctx: o.ctx, 209 outgoing: xsync.NewMapOfPresized[uint64, *muxClient](1000), 210 inStream: xsync.NewMapOfPresized[uint64, *muxServer](1000), 211 outQueue: make(chan []byte, defaultOutQueue), 212 dialer: o.dial, 213 side: ws.StateServerSide, 214 connChange: &sync.Cond{L: &sync.Mutex{}}, 215 handlers: o.handlers, 216 auth: o.auth, 217 header: make(http.Header, 1), 218 remote: &RemoteClient{Name: o.remote}, 219 clientPingInterval: clientPingInterval, 220 connPingInterval: connPingInterval, 221 tlsConfig: o.tlsConfig, 222 incomingBytes: o.incomingBytes, 223 outgoingBytes: o.outgoingBytes, 224 } 225 if debugPrint { 226 // Random Mux ID 227 c.NextID = rand.Uint64() 228 } 229 if !strings.HasPrefix(o.remote, "https://") && !strings.HasPrefix(o.remote, "wss://") { 230 c.baseFlags |= FlagCRCxxh3 231 } 232 if !strings.HasPrefix(o.local, 
"https://") && !strings.HasPrefix(o.local, "wss://") { 233 c.baseFlags |= FlagCRCxxh3 234 } 235 if o.publisher != nil { 236 c.traceRequests(o.publisher) 237 } 238 if o.local == o.remote { 239 panic("equal hosts") 240 } 241 if c.shouldConnect() { 242 c.side = ws.StateClientSide 243 244 go func() { 245 if o.blockConnect != nil { 246 <-o.blockConnect 247 } 248 c.connect() 249 }() 250 } 251 if debugPrint { 252 fmt.Println(c.Local, "->", c.Remote, "Should local connect:", c.shouldConnect(), "side:", c.side) 253 } 254 if debugReqs { 255 fmt.Println("Created connection", c.String()) 256 } 257 return c 258 } 259 260 // Subroute returns a static subroute for the connection. 261 func (c *Connection) Subroute(s string) *Subroute { 262 if c == nil { 263 return nil 264 } 265 return &Subroute{ 266 Connection: c, 267 route: s, 268 subID: makeSubHandlerID(0, s), 269 trace: c.trace.subroute(s), 270 } 271 } 272 273 // Subroute adds a subroute to the subroute. 274 // The subroutes are combined with '/'. 275 func (c *Subroute) Subroute(s string) *Subroute { 276 route := strings.Join([]string{c.route, s}, "/") 277 return &Subroute{ 278 Connection: c.Connection, 279 route: route, 280 subID: makeSubHandlerID(0, route), 281 trace: c.trace.subroute(route), 282 } 283 } 284 285 // newMuxClient returns a mux client for manual use. 286 func (c *Connection) newMuxClient(ctx context.Context) (*muxClient, error) { 287 client := newMuxClient(ctx, atomic.AddUint64(&c.NextID, 1), c) 288 if dl, ok := ctx.Deadline(); ok { 289 client.deadline = getDeadline(time.Until(dl)) 290 if client.deadline == 0 { 291 client.cancelFn(context.DeadlineExceeded) 292 return nil, context.DeadlineExceeded 293 } 294 } 295 for { 296 // Handle the extremely unlikely scenario that we wrapped. 
297 if _, loaded := c.outgoing.LoadOrStore(client.MuxID, client); client.MuxID != 0 && !loaded { 298 if debugReqs { 299 _, found := c.outgoing.Load(client.MuxID) 300 fmt.Println(client.MuxID, c.String(), "Connection.newMuxClient: RELOADED MUX. loaded:", loaded, "found:", found) 301 } 302 return client, nil 303 } 304 client.MuxID = atomic.AddUint64(&c.NextID, 1) 305 } 306 } 307 308 // newMuxClient returns a mux client for manual use. 309 func (c *Subroute) newMuxClient(ctx context.Context) (*muxClient, error) { 310 cl, err := c.Connection.newMuxClient(ctx) 311 if err != nil { 312 return nil, err 313 } 314 cl.subroute = &c.subID 315 return cl, nil 316 } 317 318 // Request allows to do a single remote request. 319 // 'req' will not be used after the call and caller can reuse. 320 // If no deadline is set on ctx, a 1-minute deadline will be added. 321 func (c *Connection) Request(ctx context.Context, h HandlerID, req []byte) ([]byte, error) { 322 if !h.valid() { 323 return nil, ErrUnknownHandler 324 } 325 if c.State() != StateConnected { 326 return nil, ErrDisconnected 327 } 328 // Create mux client and call. 329 client, err := c.newMuxClient(ctx) 330 if err != nil { 331 return nil, err 332 } 333 defer func() { 334 if debugReqs { 335 _, ok := c.outgoing.Load(client.MuxID) 336 fmt.Println(client.MuxID, c.String(), "Connection.Request: DELETING MUX. Exists:", ok) 337 } 338 client.cancelFn(context.Canceled) 339 c.outgoing.Delete(client.MuxID) 340 }() 341 return client.traceRoundtrip(ctx, c.trace, h, req) 342 } 343 344 // Request allows to do a single remote request. 345 // 'req' will not be used after the call and caller can reuse. 346 // If no deadline is set on ctx, a 1-minute deadline will be added. 347 func (c *Subroute) Request(ctx context.Context, h HandlerID, req []byte) ([]byte, error) { 348 if !h.valid() { 349 return nil, ErrUnknownHandler 350 } 351 if c.State() != StateConnected { 352 return nil, ErrDisconnected 353 } 354 // Create mux client and call. 
355 client, err := c.newMuxClient(ctx) 356 if err != nil { 357 return nil, err 358 } 359 client.subroute = &c.subID 360 defer func() { 361 if debugReqs { 362 fmt.Println(client.MuxID, c.String(), "Subroute.Request: DELETING MUX") 363 } 364 client.cancelFn(context.Canceled) 365 c.outgoing.Delete(client.MuxID) 366 }() 367 return client.traceRoundtrip(ctx, c.trace, h, req) 368 } 369 370 // NewStream creates a new stream. 371 // Initial payload can be reused by the caller. 372 func (c *Connection) NewStream(ctx context.Context, h HandlerID, payload []byte) (st *Stream, err error) { 373 if !h.valid() { 374 return nil, ErrUnknownHandler 375 } 376 if c.State() != StateConnected { 377 return nil, ErrDisconnected 378 } 379 handler := c.handlers.streams[h] 380 if handler == nil { 381 return nil, ErrUnknownHandler 382 } 383 384 var requests chan []byte 385 var responses chan Response 386 if handler.InCapacity > 0 { 387 requests = make(chan []byte, handler.InCapacity) 388 } 389 if handler.OutCapacity > 0 { 390 responses = make(chan Response, handler.OutCapacity) 391 } else { 392 responses = make(chan Response, 1) 393 } 394 395 cl, err := c.newMuxClient(ctx) 396 if err != nil { 397 return nil, err 398 } 399 400 return cl.RequestStream(h, payload, requests, responses) 401 } 402 403 // NewStream creates a new stream. 404 // Initial payload can be reused by the caller. 
405 func (c *Subroute) NewStream(ctx context.Context, h HandlerID, payload []byte) (st *Stream, err error) { 406 if !h.valid() { 407 return nil, ErrUnknownHandler 408 } 409 if c.State() != StateConnected { 410 return nil, ErrDisconnected 411 } 412 handler := c.handlers.subStreams[makeZeroSubHandlerID(h)] 413 if handler == nil { 414 if debugPrint { 415 fmt.Println("want", makeZeroSubHandlerID(h), c.route, "got", c.handlers.subStreams) 416 } 417 return nil, ErrUnknownHandler 418 } 419 420 var requests chan []byte 421 var responses chan Response 422 if handler.InCapacity > 0 { 423 requests = make(chan []byte, handler.InCapacity) 424 } 425 if handler.OutCapacity > 0 { 426 responses = make(chan Response, handler.OutCapacity) 427 } else { 428 responses = make(chan Response, 1) 429 } 430 431 cl, err := c.newMuxClient(ctx) 432 if err != nil { 433 return nil, err 434 } 435 cl.subroute = &c.subID 436 437 return cl.RequestStream(h, payload, requests, responses) 438 } 439 440 // WaitForConnect will block until a connection has been established or 441 // the context is canceled, in which case the context error is returned. 442 func (c *Connection) WaitForConnect(ctx context.Context) error { 443 if debugPrint { 444 fmt.Println(c.Local, "->", c.Remote, "WaitForConnect") 445 defer fmt.Println(c.Local, "->", c.Remote, "WaitForConnect done") 446 } 447 c.connChange.L.Lock() 448 if atomic.LoadUint32((*uint32)(&c.state)) == StateConnected { 449 c.connChange.L.Unlock() 450 // Happy path. 
451 return nil 452 } 453 ctx, cancel := context.WithCancel(ctx) 454 defer cancel() 455 changed := make(chan State, 1) 456 go func() { 457 defer xioutil.SafeClose(changed) 458 for { 459 c.connChange.Wait() 460 newState := c.State() 461 select { 462 case changed <- newState: 463 if newState == StateConnected || newState == StateShutdown { 464 c.connChange.L.Unlock() 465 return 466 } 467 case <-ctx.Done(): 468 c.connChange.L.Unlock() 469 return 470 } 471 } 472 }() 473 474 for { 475 select { 476 case <-ctx.Done(): 477 return context.Cause(ctx) 478 case newState := <-changed: 479 if newState == StateConnected { 480 return nil 481 } 482 } 483 } 484 } 485 486 /* 487 var ErrDone = errors.New("done for now") 488 489 var ErrRemoteRestart = errors.New("remote restarted") 490 491 492 // Stateless connects to the remote handler and return all packets sent back. 493 // If the remote is restarted will return ErrRemoteRestart. 494 // If nil will be returned remote call sent EOF or ErrDone is returned by the callback. 495 // If ErrDone is returned on cb nil will be returned. 496 func (c *Connection) Stateless(ctx context.Context, h HandlerID, req []byte, cb func([]byte) error) error { 497 client, err := c.newMuxClient(ctx) 498 if err != nil { 499 return err 500 } 501 defer c.outgoing.Delete(client.MuxID) 502 resp := make(chan Response, 10) 503 client.RequestStateless(h, req, resp) 504 505 for r := range resp { 506 if r.Err != nil { 507 return r.Err 508 } 509 if len(r.Msg) > 0 { 510 err := cb(r.Msg) 511 if err != nil { 512 if errors.Is(err, ErrDone) { 513 break 514 } 515 return err 516 } 517 } 518 } 519 return nil 520 } 521 */ 522 523 // shouldConnect returns a deterministic bool whether the local should initiate the connection. 524 // It should be 50% chance of any host initiating the connection. 525 func (c *Connection) shouldConnect() bool { 526 // The remote should have the opposite result. 
527 h0 := xxh3.HashString(c.Local + c.Remote) 528 h1 := xxh3.HashString(c.Remote + c.Local) 529 if h0 == h1 { 530 return c.Local < c.Remote 531 } 532 return h0 < h1 533 } 534 535 func (c *Connection) send(ctx context.Context, msg []byte) error { 536 select { 537 case <-ctx.Done(): 538 // Returning error here is too noisy. 539 return nil 540 case c.outQueue <- msg: 541 return nil 542 } 543 } 544 545 // queueMsg queues a message, with an optional payload. 546 // sender should not reference msg.Payload 547 func (c *Connection) queueMsg(msg message, payload sender) error { 548 // Add baseflags. 549 msg.Flags.Set(c.baseFlags) 550 // This cannot encode subroute. 551 msg.Flags.Clear(FlagSubroute) 552 if payload != nil { 553 if cap(msg.Payload) < payload.Msgsize() { 554 old := msg.Payload 555 msg.Payload = GetByteBuffer()[:0] 556 PutByteBuffer(old) 557 } 558 var err error 559 msg.Payload, err = payload.MarshalMsg(msg.Payload[:0]) 560 msg.Op = payload.Op() 561 if err != nil { 562 return err 563 } 564 } 565 defer PutByteBuffer(msg.Payload) 566 dst := GetByteBuffer()[:0] 567 dst, err := msg.MarshalMsg(dst) 568 if err != nil { 569 return err 570 } 571 if msg.Flags&FlagCRCxxh3 != 0 { 572 h := xxh3.Hash(dst) 573 dst = binary.LittleEndian.AppendUint32(dst, uint32(h)) 574 } 575 return c.send(c.ctx, dst) 576 } 577 578 // sendMsg will send 579 func (c *Connection) sendMsg(conn net.Conn, msg message, payload msgp.MarshalSizer) error { 580 if payload != nil { 581 if cap(msg.Payload) < payload.Msgsize() { 582 PutByteBuffer(msg.Payload) 583 msg.Payload = GetByteBuffer()[:0] 584 } 585 var err error 586 msg.Payload, err = payload.MarshalMsg(msg.Payload) 587 if err != nil { 588 return err 589 } 590 defer PutByteBuffer(msg.Payload) 591 } 592 dst := GetByteBuffer()[:0] 593 dst, err := msg.MarshalMsg(dst) 594 if err != nil { 595 return err 596 } 597 if msg.Flags&FlagCRCxxh3 != 0 { 598 h := xxh3.Hash(dst) 599 dst = binary.LittleEndian.AppendUint32(dst, uint32(h)) 600 } 601 if debugPrint { 602 
fmt.Println(c.Local, "sendMsg: Sending", msg.Op, "as", len(dst), "bytes") 603 } 604 if c.outgoingBytes != nil { 605 c.outgoingBytes(int64(len(dst))) 606 } 607 err = conn.SetWriteDeadline(time.Now().Add(connWriteTimeout)) 608 if err != nil { 609 return err 610 } 611 return wsutil.WriteMessage(conn, c.side, ws.OpBinary, dst) 612 } 613 614 func (c *Connection) connect() { 615 c.updateState(StateConnecting) 616 rng := rand.New(rand.NewSource(time.Now().UnixNano())) 617 // Runs until the server is shut down. 618 for { 619 if c.State() == StateShutdown { 620 return 621 } 622 toDial := strings.Replace(c.Remote, "http://", "ws://", 1) 623 toDial = strings.Replace(toDial, "https://", "wss://", 1) 624 toDial += RoutePath 625 626 dialer := ws.DefaultDialer 627 dialer.ReadBufferSize = readBufferSize 628 dialer.WriteBufferSize = writeBufferSize 629 dialer.Timeout = defaultDialTimeout 630 if c.dialer != nil { 631 dialer.NetDial = c.dialer.DialContext 632 } 633 if c.header == nil { 634 c.header = make(http.Header, 2) 635 } 636 c.header.Set("Authorization", "Bearer "+c.auth("")) 637 c.header.Set("X-Minio-Time", time.Now().UTC().Format(time.RFC3339)) 638 639 if len(c.header) > 0 { 640 dialer.Header = ws.HandshakeHeaderHTTP(c.header) 641 } 642 dialer.TLSConfig = c.tlsConfig 643 dialStarted := time.Now() 644 if debugPrint { 645 fmt.Println(c.Local, "Connecting to ", toDial) 646 } 647 conn, br, _, err := dialer.Dial(c.ctx, toDial) 648 if br != nil { 649 ws.PutReader(br) 650 } 651 c.connMu.Lock() 652 c.debugOutConn = conn 653 c.connMu.Unlock() 654 retry := func(err error) { 655 if debugPrint { 656 fmt.Printf("%v Connecting to %v: %v. 
Retrying.\n", c.Local, toDial, err) 657 } 658 sleep := defaultDialTimeout + time.Duration(rng.Int63n(int64(defaultDialTimeout))) 659 next := dialStarted.Add(sleep / 2) 660 sleep = time.Until(next).Round(time.Millisecond) 661 if sleep < 0 { 662 sleep = 0 663 } 664 gotState := c.State() 665 if gotState == StateShutdown { 666 return 667 } 668 if gotState != StateConnecting { 669 // Don't print error on first attempt, 670 // and after that only once per hour. 671 logger.LogOnceIf(c.ctx, fmt.Errorf("grid: %s connecting to %s: %w (%T) Sleeping %v (%v)", c.Local, toDial, err, err, sleep, gotState), toDial) 672 } 673 c.updateState(StateConnectionError) 674 time.Sleep(sleep) 675 } 676 if err != nil { 677 retry(err) 678 continue 679 } 680 // Send connect message. 681 m := message{ 682 Op: OpConnect, 683 } 684 req := connectReq{ 685 Host: c.Local, 686 ID: c.id, 687 } 688 err = c.sendMsg(conn, m, &req) 689 if err != nil { 690 retry(err) 691 continue 692 } 693 // Wait for response 694 var r connectResp 695 err = c.receive(conn, &r) 696 if err != nil { 697 if debugPrint { 698 fmt.Println(c.Local, "receive err:", err, "side:", c.side) 699 } 700 retry(err) 701 continue 702 } 703 if debugPrint { 704 fmt.Println(c.Local, "Got connectResp:", r) 705 } 706 if !r.Accepted { 707 retry(fmt.Errorf("connection rejected: %s", r.RejectedReason)) 708 continue 709 } 710 c.reconnectMu.Lock() 711 remoteUUID := uuid.UUID(r.ID) 712 if c.remoteID != nil { 713 c.reconnected() 714 } 715 c.remoteID = &remoteUUID 716 if debugPrint { 717 fmt.Println(c.Local, "Connected Waiting for Messages") 718 } 719 // Handle messages... 720 c.handleMessages(c.ctx, conn) 721 // Reconnect unless we are shutting down (debug only). 722 if c.State() == StateShutdown { 723 conn.Close() 724 return 725 } 726 if debugPrint { 727 fmt.Println(c.Local, "Disconnected. 
Attempting to reconnect.") 728 } 729 } 730 } 731 732 func (c *Connection) disconnected() { 733 c.outgoing.Range(func(key uint64, client *muxClient) bool { 734 if !client.stateless { 735 client.cancelFn(ErrDisconnected) 736 } 737 return true 738 }) 739 if debugReqs { 740 fmt.Println(c.String(), "Disconnected. Clearing outgoing.") 741 } 742 c.outgoing.Clear() 743 c.inStream.Range(func(key uint64, client *muxServer) bool { 744 client.cancel() 745 return true 746 }) 747 c.inStream.Clear() 748 } 749 750 func (c *Connection) receive(conn net.Conn, r receiver) error { 751 b, op, err := wsutil.ReadData(conn, c.side) 752 if err != nil { 753 return err 754 } 755 if op != ws.OpBinary { 756 return fmt.Errorf("unexpected connect response type %v", op) 757 } 758 if c.incomingBytes != nil { 759 c.incomingBytes(int64(len(b))) 760 } 761 762 var m message 763 _, _, err = m.parse(b) 764 if err != nil { 765 return err 766 } 767 if m.Op != r.Op() { 768 return fmt.Errorf("unexpected response OP, want %v, got %v", r.Op(), m.Op) 769 } 770 _, err = r.UnmarshalMsg(m.Payload) 771 return err 772 } 773 774 func (c *Connection) handleIncoming(ctx context.Context, conn net.Conn, req connectReq) error { 775 c.connMu.Lock() 776 c.debugInConn = conn 777 c.connMu.Unlock() 778 if c.blockConnect != nil { 779 // Block until we are allowed to connect. 
780 <-c.blockConnect 781 } 782 if req.Host != c.Remote { 783 err := fmt.Errorf("expected remote '%s', got '%s'", c.Remote, req.Host) 784 if debugPrint { 785 fmt.Println(err) 786 } 787 return err 788 } 789 if c.shouldConnect() { 790 if debugPrint { 791 fmt.Println("expected to be client side, not server side") 792 } 793 return errors.New("grid: expected to be client side, not server side") 794 } 795 msg := message{ 796 Op: OpConnectResponse, 797 } 798 799 resp := connectResp{ 800 ID: c.id, 801 Accepted: true, 802 } 803 err := c.sendMsg(conn, msg, &resp) 804 if debugPrint { 805 fmt.Printf("grid: Queued Response %+v Side: %v\n", resp, c.side) 806 } 807 if err != nil { 808 return err 809 } 810 // Signal that we are reconnected, update state and handle messages. 811 // Prevent other connections from connecting while we process. 812 c.reconnectMu.Lock() 813 if c.remoteID != nil { 814 c.reconnected() 815 } 816 rid := uuid.UUID(req.ID) 817 c.remoteID = &rid 818 819 // Handle incoming messages until disconnect. 820 c.handleMessages(ctx, conn) 821 return nil 822 } 823 824 // reconnected signals the connection has been reconnected. 825 // It will close all active requests and streams. 826 // caller *must* hold reconnectMu. 827 func (c *Connection) reconnected() { 828 c.updateState(StateConnectionError) 829 // Close all active requests. 830 if debugReqs { 831 fmt.Println(c.String(), "Reconnected. Clearing outgoing.") 832 } 833 c.outgoing.Range(func(key uint64, client *muxClient) bool { 834 client.close() 835 return true 836 }) 837 c.inStream.Range(func(key uint64, value *muxServer) bool { 838 value.close() 839 return true 840 }) 841 842 c.inStream.Clear() 843 c.outgoing.Clear() 844 845 // Wait for existing to exit 846 c.handleMsgWg.Wait() 847 } 848 849 func (c *Connection) updateState(s State) { 850 c.connChange.L.Lock() 851 defer c.connChange.L.Unlock() 852 853 // We may have reads that aren't locked, so update atomically. 
854 gotState := atomic.LoadUint32((*uint32)(&c.state)) 855 if gotState == StateShutdown || State(gotState) == s { 856 return 857 } 858 if s == StateConnected { 859 atomic.StoreInt64(&c.LastPong, time.Now().UnixNano()) 860 } 861 atomic.StoreUint32((*uint32)(&c.state), uint32(s)) 862 if debugPrint { 863 fmt.Println(c.Local, "updateState:", gotState, "->", s) 864 } 865 c.connChange.Broadcast() 866 } 867 868 // monitorState will monitor the state of the connection and close the net.Conn if it changes. 869 func (c *Connection) monitorState(conn net.Conn, cancel context.CancelCauseFunc) { 870 c.connChange.L.Lock() 871 defer c.connChange.L.Unlock() 872 for { 873 newState := c.State() 874 if newState != StateConnected { 875 conn.Close() 876 cancel(ErrDisconnected) 877 return 878 } 879 // Unlock and wait for state change. 880 c.connChange.Wait() 881 } 882 } 883 884 // handleMessages will handle incoming messages on conn. 885 // caller *must* hold reconnectMu. 886 func (c *Connection) handleMessages(ctx context.Context, conn net.Conn) { 887 c.updateState(StateConnected) 888 ctx, cancel := context.WithCancelCause(ctx) 889 defer cancel(ErrDisconnected) 890 891 // This will ensure that is something asks to disconnect and we are blocked on reads/writes 892 // the connection will be closed and readers/writers will unblock. 
893 go c.monitorState(conn, cancel) 894 895 c.handleMsgWg.Add(2) 896 c.reconnectMu.Unlock() 897 898 // Read goroutine 899 go func() { 900 defer func() { 901 if rec := recover(); rec != nil { 902 logger.LogIf(ctx, fmt.Errorf("handleMessages: panic recovered: %v", rec)) 903 debug.PrintStack() 904 } 905 c.connChange.L.Lock() 906 if atomic.CompareAndSwapUint32((*uint32)(&c.state), StateConnected, StateConnectionError) { 907 c.connChange.Broadcast() 908 } 909 c.connChange.L.Unlock() 910 conn.Close() 911 c.handleMsgWg.Done() 912 }() 913 914 controlHandler := wsutil.ControlFrameHandler(conn, c.side) 915 wsReader := wsutil.Reader{ 916 Source: conn, 917 State: c.side, 918 CheckUTF8: true, 919 SkipHeaderCheck: false, 920 OnIntermediate: controlHandler, 921 } 922 readDataInto := func(dst []byte, rw io.ReadWriter, s ws.State, want ws.OpCode) ([]byte, error) { 923 dst = dst[:0] 924 for { 925 hdr, err := wsReader.NextFrame() 926 if err != nil { 927 return nil, err 928 } 929 if hdr.OpCode.IsControl() { 930 if err := controlHandler(hdr, &wsReader); err != nil { 931 return nil, err 932 } 933 continue 934 } 935 if hdr.OpCode&want == 0 { 936 if err := wsReader.Discard(); err != nil { 937 return nil, err 938 } 939 continue 940 } 941 if int64(cap(dst)) < hdr.Length+1 { 942 dst = make([]byte, 0, hdr.Length+hdr.Length>>3) 943 } 944 return readAllInto(dst[:0], &wsReader) 945 } 946 } 947 948 // Keep reusing the same buffer. 949 var msg []byte 950 for { 951 if atomic.LoadUint32((*uint32)(&c.state)) != StateConnected { 952 cancel(ErrDisconnected) 953 return 954 } 955 if cap(msg) > readBufferSize*4 { 956 // Don't keep too much memory around. 
957 msg = nil 958 } 959 960 var err error 961 msg, err = readDataInto(msg, conn, c.side, ws.OpBinary) 962 if err != nil { 963 cancel(ErrDisconnected) 964 logger.LogIfNot(ctx, fmt.Errorf("ws read: %w", err), net.ErrClosed, io.EOF) 965 return 966 } 967 if c.incomingBytes != nil { 968 c.incomingBytes(int64(len(msg))) 969 } 970 971 // Parse the received message 972 var m message 973 subID, remain, err := m.parse(msg) 974 if err != nil { 975 logger.LogIf(ctx, fmt.Errorf("ws parse package: %w", err)) 976 cancel(ErrDisconnected) 977 return 978 } 979 if debugPrint { 980 fmt.Printf("%s Got msg: %v\n", c.Local, m) 981 } 982 if m.Op != OpMerged { 983 c.handleMsg(ctx, m, subID) 984 continue 985 } 986 // Handle merged messages. 987 messages := int(m.Seq) 988 for i := 0; i < messages; i++ { 989 if atomic.LoadUint32((*uint32)(&c.state)) != StateConnected { 990 cancel(ErrDisconnected) 991 return 992 } 993 var next []byte 994 next, remain, err = msgp.ReadBytesZC(remain) 995 if err != nil { 996 logger.LogIf(ctx, fmt.Errorf("ws read merged: %w", err)) 997 cancel(ErrDisconnected) 998 return 999 } 1000 1001 m.Payload = nil 1002 subID, _, err = m.parse(next) 1003 if err != nil { 1004 logger.LogIf(ctx, fmt.Errorf("ws parse merged: %w", err)) 1005 cancel(ErrDisconnected) 1006 return 1007 } 1008 c.handleMsg(ctx, m, subID) 1009 } 1010 } 1011 }() 1012 1013 // Write function. 
1014 defer func() { 1015 if rec := recover(); rec != nil { 1016 logger.LogIf(ctx, fmt.Errorf("handleMessages: panic recovered: %v", rec)) 1017 debug.PrintStack() 1018 } 1019 if debugPrint { 1020 fmt.Println("handleMessages: write goroutine exited") 1021 } 1022 cancel(ErrDisconnected) 1023 c.connChange.L.Lock() 1024 if atomic.CompareAndSwapUint32((*uint32)(&c.state), StateConnected, StateConnectionError) { 1025 c.connChange.Broadcast() 1026 } 1027 c.disconnected() 1028 c.connChange.L.Unlock() 1029 1030 conn.Close() 1031 c.handleMsgWg.Done() 1032 }() 1033 1034 c.connMu.Lock() 1035 connPingInterval := c.connPingInterval 1036 c.connMu.Unlock() 1037 ping := time.NewTicker(connPingInterval) 1038 pingFrame := message{ 1039 Op: OpPing, 1040 DeadlineMS: 5000, 1041 } 1042 1043 defer ping.Stop() 1044 queue := make([][]byte, 0, maxMergeMessages) 1045 merged := make([]byte, 0, writeBufferSize) 1046 var queueSize int 1047 var buf bytes.Buffer 1048 var wsw wsWriter 1049 for { 1050 var toSend []byte 1051 select { 1052 case <-ctx.Done(): 1053 return 1054 case <-ping.C: 1055 if c.State() != StateConnected { 1056 continue 1057 } 1058 lastPong := atomic.LoadInt64(&c.LastPong) 1059 if lastPong > 0 { 1060 lastPongTime := time.Unix(lastPong, 0) 1061 if d := time.Since(lastPongTime); d > connPingInterval*2 { 1062 logger.LogIf(ctx, fmt.Errorf("host %s last pong too old (%v); disconnecting", c.Remote, d.Round(time.Millisecond))) 1063 return 1064 } 1065 } 1066 var err error 1067 toSend, err = pingFrame.MarshalMsg(GetByteBuffer()[:0]) 1068 if err != nil { 1069 logger.LogIf(ctx, err) 1070 // Fake it... 
1071 atomic.StoreInt64(&c.LastPong, time.Now().Unix()) 1072 continue 1073 } 1074 case toSend = <-c.outQueue: 1075 if len(toSend) == 0 { 1076 continue 1077 } 1078 } 1079 if len(queue) < maxMergeMessages && queueSize+len(toSend) < writeBufferSize-1024 && len(c.outQueue) > 0 { 1080 queue = append(queue, toSend) 1081 queueSize += len(toSend) 1082 continue 1083 } 1084 c.connChange.L.Lock() 1085 for { 1086 state := c.State() 1087 if state == StateConnected { 1088 break 1089 } 1090 if debugPrint { 1091 fmt.Println(c.Local, "Waiting for connection ->", c.Remote, "state: ", state) 1092 } 1093 if state == StateShutdown || state == StateConnectionError { 1094 c.connChange.L.Unlock() 1095 return 1096 } 1097 c.connChange.Wait() 1098 select { 1099 case <-ctx.Done(): 1100 c.connChange.L.Unlock() 1101 return 1102 default: 1103 } 1104 } 1105 c.connChange.L.Unlock() 1106 if len(queue) == 0 { 1107 // Combine writes. 1108 buf.Reset() 1109 err := wsw.writeMessage(&buf, c.side, ws.OpBinary, toSend) 1110 if err != nil { 1111 logger.LogIf(ctx, fmt.Errorf("ws writeMessage: %w", err)) 1112 return 1113 } 1114 PutByteBuffer(toSend) 1115 err = conn.SetWriteDeadline(time.Now().Add(connWriteTimeout)) 1116 if err != nil { 1117 logger.LogIf(ctx, fmt.Errorf("conn.SetWriteDeadline: %w", err)) 1118 return 1119 } 1120 _, err = buf.WriteTo(conn) 1121 if err != nil { 1122 logger.LogIf(ctx, fmt.Errorf("ws write: %w", err)) 1123 return 1124 } 1125 continue 1126 } 1127 1128 // Merge entries and send 1129 queue = append(queue, toSend) 1130 if debugPrint { 1131 fmt.Println("Merging", len(queue), "messages") 1132 } 1133 1134 toSend = merged[:0] 1135 m := message{Op: OpMerged, Seq: uint32(len(queue))} 1136 var err error 1137 toSend, err = m.MarshalMsg(toSend) 1138 if err != nil { 1139 logger.LogIf(ctx, fmt.Errorf("msg.MarshalMsg: %w", err)) 1140 return 1141 } 1142 // Append as byte slices. 
1143 for _, q := range queue { 1144 toSend = msgp.AppendBytes(toSend, q) 1145 PutByteBuffer(q) 1146 } 1147 queue = queue[:0] 1148 queueSize = 0 1149 1150 // Combine writes. 1151 // Consider avoiding buffer copy. 1152 buf.Reset() 1153 err = wsw.writeMessage(&buf, c.side, ws.OpBinary, toSend) 1154 if err != nil { 1155 logger.LogIf(ctx, fmt.Errorf("ws writeMessage: %w", err)) 1156 return 1157 } 1158 // buf is our local buffer, so we can reuse it. 1159 err = conn.SetWriteDeadline(time.Now().Add(connWriteTimeout)) 1160 if err != nil { 1161 logger.LogIf(ctx, fmt.Errorf("conn.SetWriteDeadline: %w", err)) 1162 return 1163 } 1164 _, err = buf.WriteTo(conn) 1165 if err != nil { 1166 logger.LogIf(ctx, fmt.Errorf("ws write: %w", err)) 1167 return 1168 } 1169 1170 if buf.Cap() > writeBufferSize*4 { 1171 // Reset buffer if it gets too big, so we don't keep it around. 1172 buf = bytes.Buffer{} 1173 } 1174 } 1175 } 1176 1177 func (c *Connection) handleMsg(ctx context.Context, m message, subID *subHandlerID) { 1178 switch m.Op { 1179 case OpMuxServerMsg: 1180 c.handleMuxServerMsg(ctx, m) 1181 case OpResponse: 1182 c.handleResponse(m) 1183 case OpMuxClientMsg: 1184 c.handleMuxClientMsg(ctx, m) 1185 case OpUnblockSrvMux: 1186 c.handleUnblockSrvMux(m) 1187 case OpUnblockClMux: 1188 c.handleUnblockClMux(m) 1189 case OpDisconnectServerMux: 1190 c.handleDisconnectServerMux(m) 1191 case OpDisconnectClientMux: 1192 c.handleDisconnectClientMux(m) 1193 case OpPing: 1194 c.handlePing(ctx, m) 1195 case OpPong: 1196 c.handlePong(ctx, m) 1197 case OpRequest: 1198 c.handleRequest(ctx, m, subID) 1199 case OpAckMux: 1200 c.handleAckMux(ctx, m) 1201 case OpConnectMux: 1202 c.handleConnectMux(ctx, m, subID) 1203 case OpMuxConnectError: 1204 c.handleConnectMuxError(ctx, m) 1205 default: 1206 logger.LogIf(ctx, fmt.Errorf("unknown message type: %v", m.Op)) 1207 } 1208 } 1209 1210 func (c *Connection) handleConnectMux(ctx context.Context, m message, subID *subHandlerID) { 1211 // Stateless stream: 1212 
if m.Flags&FlagStateless != 0 { 1213 // Reject for now, so we can safely add it later. 1214 if true { 1215 logger.LogIf(ctx, c.queueMsg(m, muxConnectError{Error: "Stateless streams not supported"})) 1216 return 1217 } 1218 1219 var handler *StatelessHandler 1220 if subID == nil { 1221 handler = c.handlers.stateless[m.Handler] 1222 } else { 1223 handler = c.handlers.subStateless[*subID] 1224 } 1225 if handler == nil { 1226 logger.LogIf(ctx, c.queueMsg(m, muxConnectError{Error: "Invalid Handler for type"})) 1227 return 1228 } 1229 _, _ = c.inStream.LoadOrCompute(m.MuxID, func() *muxServer { 1230 return newMuxStateless(ctx, m, c, *handler) 1231 }) 1232 } else { 1233 // Stream: 1234 var handler *StreamHandler 1235 if subID == nil { 1236 if !m.Handler.valid() { 1237 logger.LogIf(ctx, c.queueMsg(m, muxConnectError{Error: "Invalid Handler"})) 1238 return 1239 } 1240 handler = c.handlers.streams[m.Handler] 1241 } else { 1242 handler = c.handlers.subStreams[*subID] 1243 } 1244 if handler == nil { 1245 logger.LogIf(ctx, c.queueMsg(m, muxConnectError{Error: "Invalid Handler for type"})) 1246 return 1247 } 1248 1249 // Start a new server handler if none exists. 1250 _, _ = c.inStream.LoadOrCompute(m.MuxID, func() *muxServer { 1251 return newMuxStream(ctx, m, c, *handler) 1252 }) 1253 } 1254 } 1255 1256 // handleConnectMuxError when mux connect was rejected. 
1257 func (c *Connection) handleConnectMuxError(ctx context.Context, m message) { 1258 if v, ok := c.outgoing.Load(m.MuxID); ok { 1259 var cErr muxConnectError 1260 _, err := cErr.UnmarshalMsg(m.Payload) 1261 logger.LogIf(ctx, err) 1262 v.error(RemoteErr(cErr.Error)) 1263 return 1264 } 1265 PutByteBuffer(m.Payload) 1266 } 1267 1268 func (c *Connection) handleAckMux(ctx context.Context, m message) { 1269 PutByteBuffer(m.Payload) 1270 v, ok := c.outgoing.Load(m.MuxID) 1271 if !ok { 1272 if m.Flags&FlagEOF == 0 { 1273 logger.LogIf(ctx, c.queueMsg(message{Op: OpDisconnectClientMux, MuxID: m.MuxID}, nil)) 1274 } 1275 return 1276 } 1277 if debugPrint { 1278 fmt.Println(c.Local, "Mux", m.MuxID, "Acknowledged") 1279 } 1280 v.ack(m.Seq) 1281 } 1282 1283 func (c *Connection) handleRequest(ctx context.Context, m message, subID *subHandlerID) { 1284 if !m.Handler.valid() { 1285 logger.LogIf(ctx, c.queueMsg(m, muxConnectError{Error: "Invalid Handler"})) 1286 return 1287 } 1288 if debugReqs { 1289 fmt.Println(m.MuxID, c.StringReverse(), "INCOMING") 1290 } 1291 // Singleshot message 1292 var handler SingleHandlerFn 1293 if subID == nil { 1294 handler = c.handlers.single[m.Handler] 1295 } else { 1296 handler = c.handlers.subSingle[*subID] 1297 } 1298 if handler == nil { 1299 logger.LogIf(ctx, c.queueMsg(m, muxConnectError{Error: "Invalid Handler for type"})) 1300 return 1301 } 1302 1303 // TODO: This causes allocations, but escape analysis doesn't really show the cause. 1304 // If another faithful engineer wants to take a stab, feel free. 
1305 go func(m message) { 1306 var start time.Time 1307 if m.DeadlineMS > 0 { 1308 start = time.Now() 1309 } 1310 var b []byte 1311 var err *RemoteErr 1312 func() { 1313 defer func() { 1314 if rec := recover(); rec != nil { 1315 err = NewRemoteErrString(fmt.Sprintf("handleMessages: panic recovered: %v", rec)) 1316 debug.PrintStack() 1317 logger.LogIf(ctx, err) 1318 } 1319 }() 1320 b, err = handler(m.Payload) 1321 if debugPrint { 1322 fmt.Println(c.Local, "Handler returned payload:", bytesOrLength(b), "err:", err) 1323 } 1324 }() 1325 1326 if m.DeadlineMS > 0 && time.Since(start).Milliseconds()+c.addDeadline.Milliseconds() > int64(m.DeadlineMS) { 1327 if debugReqs { 1328 fmt.Println(m.MuxID, c.StringReverse(), "DEADLINE EXCEEDED") 1329 } 1330 // No need to return result 1331 PutByteBuffer(b) 1332 return 1333 } 1334 if debugReqs { 1335 fmt.Println(m.MuxID, c.StringReverse(), "RESPONDING") 1336 } 1337 m = message{ 1338 MuxID: m.MuxID, 1339 Seq: m.Seq, 1340 Op: OpResponse, 1341 Flags: FlagEOF, 1342 } 1343 if err != nil { 1344 m.Flags |= FlagPayloadIsErr 1345 m.Payload = []byte(*err) 1346 } else { 1347 m.Payload = b 1348 m.setZeroPayloadFlag() 1349 } 1350 logger.LogIf(ctx, c.queueMsg(m, nil)) 1351 }(m) 1352 } 1353 1354 func (c *Connection) handlePong(ctx context.Context, m message) { 1355 var pong pongMsg 1356 _, err := pong.UnmarshalMsg(m.Payload) 1357 PutByteBuffer(m.Payload) 1358 logger.LogIf(ctx, err) 1359 if m.MuxID == 0 { 1360 atomic.StoreInt64(&c.LastPong, time.Now().Unix()) 1361 return 1362 } 1363 if v, ok := c.outgoing.Load(m.MuxID); ok { 1364 v.pong(pong) 1365 } else { 1366 // We don't care if the client was removed in the meantime, 1367 // but we send a disconnect message to the server just in case. 
1368 logger.LogIf(ctx, c.queueMsg(message{Op: OpDisconnectClientMux, MuxID: m.MuxID}, nil)) 1369 } 1370 } 1371 1372 func (c *Connection) handlePing(ctx context.Context, m message) { 1373 if m.MuxID == 0 { 1374 logger.LogIf(ctx, c.queueMsg(m, &pongMsg{})) 1375 return 1376 } 1377 // Single calls do not support pinging. 1378 if v, ok := c.inStream.Load(m.MuxID); ok { 1379 pong := v.ping(m.Seq) 1380 logger.LogIf(ctx, c.queueMsg(m, &pong)) 1381 } else { 1382 pong := pongMsg{NotFound: true} 1383 logger.LogIf(ctx, c.queueMsg(m, &pong)) 1384 } 1385 return 1386 } 1387 1388 func (c *Connection) handleDisconnectClientMux(m message) { 1389 if v, ok := c.outgoing.Load(m.MuxID); ok { 1390 if m.Flags&FlagPayloadIsErr != 0 { 1391 v.error(RemoteErr(m.Payload)) 1392 } else { 1393 v.error(ErrDisconnected) 1394 } 1395 return 1396 } 1397 PutByteBuffer(m.Payload) 1398 } 1399 1400 func (c *Connection) handleDisconnectServerMux(m message) { 1401 if debugPrint { 1402 fmt.Println(c.Local, "Disconnect server mux:", m.MuxID) 1403 } 1404 PutByteBuffer(m.Payload) 1405 m.Payload = nil 1406 if v, ok := c.inStream.Load(m.MuxID); ok { 1407 v.close() 1408 } 1409 } 1410 1411 func (c *Connection) handleUnblockClMux(m message) { 1412 PutByteBuffer(m.Payload) 1413 m.Payload = nil 1414 v, ok := c.outgoing.Load(m.MuxID) 1415 if !ok { 1416 if debugPrint { 1417 fmt.Println(c.Local, "Unblock: Unknown Mux:", m.MuxID) 1418 } 1419 // We can expect to receive unblocks for closed muxes 1420 return 1421 } 1422 v.unblockSend(m.Seq) 1423 } 1424 1425 func (c *Connection) handleUnblockSrvMux(m message) { 1426 if m.Payload != nil { 1427 PutByteBuffer(m.Payload) 1428 } 1429 m.Payload = nil 1430 if v, ok := c.inStream.Load(m.MuxID); ok { 1431 v.unblockSend(m.Seq) 1432 return 1433 } 1434 // We can expect to receive unblocks for closed muxes 1435 if debugPrint { 1436 fmt.Println(c.Local, "Unblock: Unknown Mux:", m.MuxID) 1437 } 1438 } 1439 1440 func (c *Connection) handleMuxClientMsg(ctx context.Context, m message) { 1441 
v, ok := c.inStream.Load(m.MuxID) 1442 if !ok { 1443 if debugPrint { 1444 fmt.Println(c.Local, "OpMuxClientMsg: Unknown Mux:", m.MuxID) 1445 } 1446 logger.LogIf(ctx, c.queueMsg(message{Op: OpDisconnectClientMux, MuxID: m.MuxID}, nil)) 1447 PutByteBuffer(m.Payload) 1448 return 1449 } 1450 v.message(m) 1451 } 1452 1453 func (c *Connection) handleResponse(m message) { 1454 if debugPrint { 1455 fmt.Printf("%s Got mux response: %v\n", c.Local, m) 1456 } 1457 v, ok := c.outgoing.Load(m.MuxID) 1458 if !ok { 1459 if debugReqs { 1460 fmt.Println(m.MuxID, c.String(), "Got response for unknown mux") 1461 } 1462 PutByteBuffer(m.Payload) 1463 return 1464 } 1465 if m.Flags&FlagPayloadIsErr != 0 { 1466 v.response(m.Seq, Response{ 1467 Msg: nil, 1468 Err: RemoteErr(m.Payload), 1469 }) 1470 PutByteBuffer(m.Payload) 1471 } else { 1472 v.response(m.Seq, Response{ 1473 Msg: m.Payload, 1474 Err: nil, 1475 }) 1476 } 1477 v.close() 1478 if debugReqs { 1479 fmt.Println(m.MuxID, c.String(), "handleResponse: closing mux") 1480 } 1481 } 1482 1483 func (c *Connection) handleMuxServerMsg(ctx context.Context, m message) { 1484 if debugPrint { 1485 fmt.Printf("%s Got mux msg: %v\n", c.Local, m) 1486 } 1487 v, ok := c.outgoing.Load(m.MuxID) 1488 if !ok { 1489 if m.Flags&FlagEOF == 0 { 1490 logger.LogIf(ctx, c.queueMsg(message{Op: OpDisconnectClientMux, MuxID: m.MuxID}, nil)) 1491 } 1492 PutByteBuffer(m.Payload) 1493 return 1494 } 1495 if m.Flags&FlagPayloadIsErr != 0 { 1496 v.response(m.Seq, Response{ 1497 Msg: nil, 1498 Err: RemoteErr(m.Payload), 1499 }) 1500 PutByteBuffer(m.Payload) 1501 } else if m.Payload != nil { 1502 v.response(m.Seq, Response{ 1503 Msg: m.Payload, 1504 Err: nil, 1505 }) 1506 } 1507 if m.Flags&FlagEOF != 0 { 1508 if v.cancelFn != nil && m.Flags&FlagPayloadIsErr == 0 { 1509 v.cancelFn(errStreamEOF) 1510 } 1511 v.close() 1512 if debugReqs { 1513 fmt.Println(m.MuxID, c.String(), "handleMuxServerMsg: DELETING MUX") 1514 } 1515 c.outgoing.Delete(m.MuxID) 1516 } 1517 } 1518 1519 
func (c *Connection) deleteMux(incoming bool, muxID uint64) { 1520 if incoming { 1521 if debugPrint { 1522 fmt.Println("deleteMux: disconnect incoming mux", muxID) 1523 } 1524 v, loaded := c.inStream.LoadAndDelete(muxID) 1525 if loaded && v != nil { 1526 logger.LogIf(c.ctx, c.queueMsg(message{Op: OpDisconnectClientMux, MuxID: muxID}, nil)) 1527 v.close() 1528 } 1529 } else { 1530 if debugPrint { 1531 fmt.Println("deleteMux: disconnect outgoing mux", muxID) 1532 } 1533 v, loaded := c.outgoing.LoadAndDelete(muxID) 1534 if loaded && v != nil { 1535 if debugReqs { 1536 fmt.Println(muxID, c.String(), "deleteMux: DELETING MUX") 1537 } 1538 v.close() 1539 logger.LogIf(c.ctx, c.queueMsg(message{Op: OpDisconnectServerMux, MuxID: muxID}, nil)) 1540 } 1541 } 1542 } 1543 1544 // State returns the current connection status. 1545 func (c *Connection) State() State { 1546 return State(atomic.LoadUint32((*uint32)(&c.state))) 1547 } 1548 1549 // Stats returns the current connection stats. 1550 func (c *Connection) Stats() ConnectionStats { 1551 return ConnectionStats{ 1552 IncomingStreams: c.inStream.Size(), 1553 OutgoingStreams: c.outgoing.Size(), 1554 } 1555 } 1556 1557 func (c *Connection) debugMsg(d debugMsg, args ...any) { 1558 if debugPrint { 1559 fmt.Println("debug: sending message", d, args) 1560 } 1561 1562 switch d { 1563 case debugShutdown: 1564 c.updateState(StateShutdown) 1565 case debugKillInbound: 1566 c.connMu.Lock() 1567 defer c.connMu.Unlock() 1568 if c.debugInConn != nil { 1569 if debugPrint { 1570 fmt.Println("debug: closing inbound connection") 1571 } 1572 c.debugInConn.Close() 1573 } 1574 case debugKillOutbound: 1575 c.connMu.Lock() 1576 defer c.connMu.Unlock() 1577 if c.debugInConn != nil { 1578 if debugPrint { 1579 fmt.Println("debug: closing outgoing connection") 1580 } 1581 c.debugInConn.Close() 1582 } 1583 case debugWaitForExit: 1584 c.reconnectMu.Lock() 1585 c.handleMsgWg.Wait() 1586 c.reconnectMu.Unlock() 1587 case debugSetConnPingDuration: 1588 
c.connMu.Lock() 1589 defer c.connMu.Unlock() 1590 c.connPingInterval = args[0].(time.Duration) 1591 case debugSetClientPingDuration: 1592 c.clientPingInterval = args[0].(time.Duration) 1593 case debugAddToDeadline: 1594 c.addDeadline = args[0].(time.Duration) 1595 case debugIsOutgoingClosed: 1596 // params: muxID uint64, isClosed func(bool) 1597 muxID := args[0].(uint64) 1598 resp := args[1].(func(b bool)) 1599 mid, ok := c.outgoing.Load(muxID) 1600 if !ok || mid == nil { 1601 resp(true) 1602 return 1603 } 1604 mid.respMu.Lock() 1605 resp(mid.closed) 1606 mid.respMu.Unlock() 1607 } 1608 } 1609 1610 // wsWriter writes websocket messages. 1611 type wsWriter struct { 1612 tmp [ws.MaxHeaderSize]byte 1613 } 1614 1615 // writeMessage writes a message to w without allocations. 1616 func (ww *wsWriter) writeMessage(w io.Writer, s ws.State, op ws.OpCode, p []byte) error { 1617 const fin = true 1618 var frame ws.Frame 1619 if s.ClientSide() { 1620 // We do not need to copy the payload, since we own it. 1621 payload := p 1622 1623 frame = ws.NewFrame(op, fin, payload) 1624 frame = ws.MaskFrameInPlace(frame) 1625 } else { 1626 frame = ws.NewFrame(op, fin, p) 1627 } 1628 1629 return ww.writeFrame(w, frame) 1630 } 1631 1632 // writeFrame writes frame binary representation into w. 
1633 func (ww *wsWriter) writeFrame(w io.Writer, f ws.Frame) error { 1634 const ( 1635 bit0 = 0x80 1636 len7 = int64(125) 1637 len16 = int64(^(uint16(0))) 1638 len64 = int64(^(uint64(0)) >> 1) 1639 ) 1640 1641 bts := ww.tmp[:] 1642 if f.Header.Fin { 1643 bts[0] |= bit0 1644 } 1645 bts[0] |= f.Header.Rsv << 4 1646 bts[0] |= byte(f.Header.OpCode) 1647 1648 var n int 1649 switch { 1650 case f.Header.Length <= len7: 1651 bts[1] = byte(f.Header.Length) 1652 n = 2 1653 1654 case f.Header.Length <= len16: 1655 bts[1] = 126 1656 binary.BigEndian.PutUint16(bts[2:4], uint16(f.Header.Length)) 1657 n = 4 1658 1659 case f.Header.Length <= len64: 1660 bts[1] = 127 1661 binary.BigEndian.PutUint64(bts[2:10], uint64(f.Header.Length)) 1662 n = 10 1663 1664 default: 1665 return ws.ErrHeaderLengthUnexpected 1666 } 1667 1668 if f.Header.Masked { 1669 bts[1] |= bit0 1670 n += copy(bts[n:], f.Header.Mask[:]) 1671 } 1672 1673 if _, err := w.Write(bts[:n]); err != nil { 1674 return err 1675 } 1676 1677 _, err := w.Write(f.Payload) 1678 return err 1679 }