github.com/minio/minio@v0.0.0-20240328213742-3f72439b8a27/internal/grid/connection.go (about)

     1  // Copyright (c) 2015-2023 MinIO, Inc.
     2  //
     3  // This file is part of MinIO Object Storage stack
     4  //
     5  // This program is free software: you can redistribute it and/or modify
     6  // it under the terms of the GNU Affero General Public License as published by
     7  // the Free Software Foundation, either version 3 of the License, or
     8  // (at your option) any later version.
     9  //
    10  // This program is distributed in the hope that it will be useful
    11  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    12  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13  // GNU Affero General Public License for more details.
    14  //
    15  // You should have received a copy of the GNU Affero General Public License
    16  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17  
    18  package grid
    19  
    20  import (
    21  	"bytes"
    22  	"context"
    23  	"crypto/tls"
    24  	"encoding/binary"
    25  	"errors"
    26  	"fmt"
    27  	"io"
    28  	"math"
    29  	"math/rand"
    30  	"net"
    31  	"net/http"
    32  	"runtime/debug"
    33  	"strings"
    34  	"sync"
    35  	"sync/atomic"
    36  	"time"
    37  
    38  	"github.com/gobwas/ws"
    39  	"github.com/gobwas/ws/wsutil"
    40  	"github.com/google/uuid"
    41  	"github.com/minio/madmin-go/v3"
    42  	xioutil "github.com/minio/minio/internal/ioutil"
    43  	"github.com/minio/minio/internal/logger"
    44  	"github.com/minio/minio/internal/pubsub"
    45  	"github.com/puzpuzpuz/xsync/v3"
    46  	"github.com/tinylib/msgp/msgp"
    47  	"github.com/zeebo/xxh3"
    48  )
    49  
// A Connection is a remote connection.
// There is no distinction externally whether the connection was initiated from
// this server or from the remote.
//
// NextID and LastPong are updated through 64-bit sync/atomic calls and are
// declared first in the struct; keep them at the top so they remain 64-bit
// aligned on 32-bit platforms.
type Connection struct {
	// NextID is the next ID that can be used (atomic).
	// In debugPrint mode newConnection seeds it with a random value.
	NextID uint64

	// LastPong is last pong time (atomic)
	// Only valid when StateConnected.
	// updateState stores the current time (UnixNano) here when switching to StateConnected.
	LastPong int64

	// State of the connection (atomic)
	// Written by updateState while holding connChange.L.
	state State

	// Non-atomic
	// Remote and Local are the host strings of the two peers.
	// They are hashed by shouldConnect to decide which side dials.
	Remote string
	Local  string

	// ID of this connection instance.
	id uuid.UUID

	// Remote uuid, if we have been connected.
	remoteID *uuid.UUID
	// reconnectMu is locked by connect/handleIncoming before remoteID is
	// updated and is released inside handleMessages.
	reconnectMu sync.Mutex

	// Context for the server.
	ctx context.Context

	// Active mux connections.
	outgoing *xsync.MapOf[uint64, *muxClient]

	// Incoming streams
	inStream *xsync.MapOf[uint64, *muxServer]

	// outQueue is the output queue
	// Fully marshalled messages are pushed here by send.
	outQueue chan []byte

	// Client or serverside.
	// Defaults to server side; newConnection switches to client side when shouldConnect is true.
	side ws.State

	// Transport for outgoing connections.
	dialer ContextDialer
	// header holds the HTTP headers sent with the websocket handshake;
	// connect refreshes "Authorization" and "X-Minio-Time" on every attempt.
	header http.Header

	// handleMsgWg is incremented by handleMessages and waited on by reconnected.
	handleMsgWg sync.WaitGroup

	// connChange will be signaled whenever State has been updated, or at regular intervals.
	// Holding the lock allows safe reads of State, and guarantees that changes will be detected.
	connChange *sync.Cond
	handlers   *handlers

	remote *RemoteClient
	// auth produces the bearer token placed in the Authorization header by connect.
	auth               AuthFn
	clientPingInterval time.Duration
	connPingInterval   time.Duration
	tlsConfig          *tls.Config
	// blockConnect, when non-nil, makes handleIncoming wait for a receive on it
	// before processing an incoming connection.
	// NOTE(review): newConnection does not populate this field in the visible
	// code — presumably it is assigned elsewhere (tests?); verify.
	blockConnect chan struct{}

	incomingBytes func(n int64) // Record incoming bytes.
	outgoingBytes func(n int64) // Record outgoing bytes.
	trace         *tracer       // tracer for this connection.
	// baseFlags are OR'ed into every message queued through queueMsg.
	baseFlags Flags

	// For testing only
	// debugInConn/debugOutConn record the most recent incoming/outgoing
	// net.Conn and are guarded by connMu.
	debugInConn  net.Conn
	debugOutConn net.Conn
	addDeadline  time.Duration
	connMu       sync.Mutex
}
   119  
// Subroute is a connection subroute that can be used to route to a specific handler with the same handler ID.
// It embeds the parent *Connection, so all Connection methods that are not
// redefined on Subroute are promoted unchanged.
type Subroute struct {
	*Connection
	// trace is the subroute-specific tracer; it shadows Connection.trace.
	trace *tracer
	// route is the full route string; nested subroutes are joined with '/'.
	route string
	// subID is derived from route via makeSubHandlerID(0, route).
	subID subHandlerID
}
   127  
   128  // String returns a string representation of the connection.
   129  func (c *Connection) String() string {
   130  	return fmt.Sprintf("%s->%s", c.Local, c.Remote)
   131  }
   132  
   133  // StringReverse returns a string representation of the reverse connection.
   134  func (c *Connection) StringReverse() string {
   135  	return fmt.Sprintf("%s->%s", c.Remote, c.Local)
   136  }
   137  
// State is a connection state.
// It is stored in Connection.state and accessed via sync/atomic by casting
// the field to *uint32, so the underlying type must remain uint32.
type State uint32
   140  
   141  // MANUAL go:generate stringer -type=State -output=state_string.go -trimprefix=State $GOFILE
   142  
// Connection states and limits.
// The state constants are declared with a bare iota and are therefore untyped
// integer constants; this lets them be compared against both State values and
// the raw uint32 returned by atomic loads.
const (
	// StateUnconnected is the initial state of a connection.
	// When the first message is sent it will attempt to connect.
	StateUnconnected = iota

	// StateConnecting is the state entered from StateUnconnected while the connection is being established.
	// After this the connection will be StateConnected or StateConnectionError.
	StateConnecting

	// StateConnected is the state when the connection has been established and is considered stable.
	// If the connection is lost, state will switch to StateConnecting.
	StateConnected

	// StateConnectionError is the state once a connection attempt has been made, and it failed.
	// The connection will remain in this state until the connection has been successfully re-established.
	StateConnectionError

	// StateShutdown is the state when the server has been shut down.
	// This will not be used under normal operation.
	// Once reached, updateState ignores all further state changes.
	StateShutdown

	// MaxDeadline is the maximum deadline allowed,
	// Approx 49 days.
	// It is math.MaxUint32 milliseconds.
	MaxDeadline = time.Duration(math.MaxUint32) * time.Millisecond
)
   168  
   169  // ContextDialer is a dialer that can be used to dial a remote.
   170  type ContextDialer func(ctx context.Context, network, address string) (net.Conn, error)
   171  
   172  // DialContext implements the Dialer interface.
   173  func (c ContextDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
   174  	return c(ctx, network, address)
   175  }
   176  
// Tuning constants for grid connections.
const (
	// defaultOutQueue is the buffer size of Connection.outQueue.
	defaultOutQueue = 65535 // kind of close to max open fds per user
	// readBufferSize is the websocket dialer read buffer size; the read loop
	// also drops its reusable buffer once it grows beyond 4x this size.
	readBufferSize = 32 << 10 // 32 KiB is the most optimal on Linux
	// writeBufferSize is the websocket dialer write buffer size.
	writeBufferSize = 32 << 10 // 32 KiB is the most optimal on Linux
	// defaultDialTimeout is the dial timeout and the base of the randomized
	// retry delay used by connect.
	defaultDialTimeout = 2 * time.Second
	// connPingInterval is the default value of Connection.connPingInterval.
	connPingInterval = 10 * time.Second
	// connWriteTimeout is the write deadline applied by sendMsg.
	connWriteTimeout = 3 * time.Second
)
   185  
// connectionParams bundles the arguments for newConnection.
type connectionParams struct {
	ctx context.Context
	// id is the ID of the local connection instance.
	id uuid.UUID
	// local and remote are the host strings of the two peers;
	// newConnection panics if they are equal.
	local, remote string
	// dial is the optional transport used for outgoing connections.
	dial     ContextDialer
	handlers *handlers
	// auth produces the bearer token sent when connecting.
	auth          AuthFn
	tlsConfig     *tls.Config
	incomingBytes func(n int64) // Record incoming bytes.
	outgoingBytes func(n int64) // Record outgoing bytes.
	// publisher, when non-nil, enables request tracing on the connection.
	publisher *pubsub.PubSub[madmin.TraceInfo, madmin.TraceType]

	// blockConnect, when non-nil, delays the outgoing connect goroutine
	// until a receive on the channel succeeds.
	blockConnect chan struct{}
}
   200  
   201  // newConnection will create an unconnected connection to a remote.
   202  func newConnection(o connectionParams) *Connection {
   203  	c := &Connection{
   204  		state:              StateUnconnected,
   205  		Remote:             o.remote,
   206  		Local:              o.local,
   207  		id:                 o.id,
   208  		ctx:                o.ctx,
   209  		outgoing:           xsync.NewMapOfPresized[uint64, *muxClient](1000),
   210  		inStream:           xsync.NewMapOfPresized[uint64, *muxServer](1000),
   211  		outQueue:           make(chan []byte, defaultOutQueue),
   212  		dialer:             o.dial,
   213  		side:               ws.StateServerSide,
   214  		connChange:         &sync.Cond{L: &sync.Mutex{}},
   215  		handlers:           o.handlers,
   216  		auth:               o.auth,
   217  		header:             make(http.Header, 1),
   218  		remote:             &RemoteClient{Name: o.remote},
   219  		clientPingInterval: clientPingInterval,
   220  		connPingInterval:   connPingInterval,
   221  		tlsConfig:          o.tlsConfig,
   222  		incomingBytes:      o.incomingBytes,
   223  		outgoingBytes:      o.outgoingBytes,
   224  	}
   225  	if debugPrint {
   226  		// Random Mux ID
   227  		c.NextID = rand.Uint64()
   228  	}
   229  	if !strings.HasPrefix(o.remote, "https://") && !strings.HasPrefix(o.remote, "wss://") {
   230  		c.baseFlags |= FlagCRCxxh3
   231  	}
   232  	if !strings.HasPrefix(o.local, "https://") && !strings.HasPrefix(o.local, "wss://") {
   233  		c.baseFlags |= FlagCRCxxh3
   234  	}
   235  	if o.publisher != nil {
   236  		c.traceRequests(o.publisher)
   237  	}
   238  	if o.local == o.remote {
   239  		panic("equal hosts")
   240  	}
   241  	if c.shouldConnect() {
   242  		c.side = ws.StateClientSide
   243  
   244  		go func() {
   245  			if o.blockConnect != nil {
   246  				<-o.blockConnect
   247  			}
   248  			c.connect()
   249  		}()
   250  	}
   251  	if debugPrint {
   252  		fmt.Println(c.Local, "->", c.Remote, "Should local connect:", c.shouldConnect(), "side:", c.side)
   253  	}
   254  	if debugReqs {
   255  		fmt.Println("Created connection", c.String())
   256  	}
   257  	return c
   258  }
   259  
   260  // Subroute returns a static subroute for the connection.
   261  func (c *Connection) Subroute(s string) *Subroute {
   262  	if c == nil {
   263  		return nil
   264  	}
   265  	return &Subroute{
   266  		Connection: c,
   267  		route:      s,
   268  		subID:      makeSubHandlerID(0, s),
   269  		trace:      c.trace.subroute(s),
   270  	}
   271  }
   272  
   273  // Subroute adds a subroute to the subroute.
   274  // The subroutes are combined with '/'.
   275  func (c *Subroute) Subroute(s string) *Subroute {
   276  	route := strings.Join([]string{c.route, s}, "/")
   277  	return &Subroute{
   278  		Connection: c.Connection,
   279  		route:      route,
   280  		subID:      makeSubHandlerID(0, route),
   281  		trace:      c.trace.subroute(route),
   282  	}
   283  }
   284  
   285  // newMuxClient returns a mux client for manual use.
   286  func (c *Connection) newMuxClient(ctx context.Context) (*muxClient, error) {
   287  	client := newMuxClient(ctx, atomic.AddUint64(&c.NextID, 1), c)
   288  	if dl, ok := ctx.Deadline(); ok {
   289  		client.deadline = getDeadline(time.Until(dl))
   290  		if client.deadline == 0 {
   291  			client.cancelFn(context.DeadlineExceeded)
   292  			return nil, context.DeadlineExceeded
   293  		}
   294  	}
   295  	for {
   296  		// Handle the extremely unlikely scenario that we wrapped.
   297  		if _, loaded := c.outgoing.LoadOrStore(client.MuxID, client); client.MuxID != 0 && !loaded {
   298  			if debugReqs {
   299  				_, found := c.outgoing.Load(client.MuxID)
   300  				fmt.Println(client.MuxID, c.String(), "Connection.newMuxClient: RELOADED MUX. loaded:", loaded, "found:", found)
   301  			}
   302  			return client, nil
   303  		}
   304  		client.MuxID = atomic.AddUint64(&c.NextID, 1)
   305  	}
   306  }
   307  
   308  // newMuxClient returns a mux client for manual use.
   309  func (c *Subroute) newMuxClient(ctx context.Context) (*muxClient, error) {
   310  	cl, err := c.Connection.newMuxClient(ctx)
   311  	if err != nil {
   312  		return nil, err
   313  	}
   314  	cl.subroute = &c.subID
   315  	return cl, nil
   316  }
   317  
   318  // Request allows to do a single remote request.
   319  // 'req' will not be used after the call and caller can reuse.
   320  // If no deadline is set on ctx, a 1-minute deadline will be added.
   321  func (c *Connection) Request(ctx context.Context, h HandlerID, req []byte) ([]byte, error) {
   322  	if !h.valid() {
   323  		return nil, ErrUnknownHandler
   324  	}
   325  	if c.State() != StateConnected {
   326  		return nil, ErrDisconnected
   327  	}
   328  	// Create mux client and call.
   329  	client, err := c.newMuxClient(ctx)
   330  	if err != nil {
   331  		return nil, err
   332  	}
   333  	defer func() {
   334  		if debugReqs {
   335  			_, ok := c.outgoing.Load(client.MuxID)
   336  			fmt.Println(client.MuxID, c.String(), "Connection.Request: DELETING MUX. Exists:", ok)
   337  		}
   338  		client.cancelFn(context.Canceled)
   339  		c.outgoing.Delete(client.MuxID)
   340  	}()
   341  	return client.traceRoundtrip(ctx, c.trace, h, req)
   342  }
   343  
   344  // Request allows to do a single remote request.
   345  // 'req' will not be used after the call and caller can reuse.
   346  // If no deadline is set on ctx, a 1-minute deadline will be added.
   347  func (c *Subroute) Request(ctx context.Context, h HandlerID, req []byte) ([]byte, error) {
   348  	if !h.valid() {
   349  		return nil, ErrUnknownHandler
   350  	}
   351  	if c.State() != StateConnected {
   352  		return nil, ErrDisconnected
   353  	}
   354  	// Create mux client and call.
   355  	client, err := c.newMuxClient(ctx)
   356  	if err != nil {
   357  		return nil, err
   358  	}
   359  	client.subroute = &c.subID
   360  	defer func() {
   361  		if debugReqs {
   362  			fmt.Println(client.MuxID, c.String(), "Subroute.Request: DELETING MUX")
   363  		}
   364  		client.cancelFn(context.Canceled)
   365  		c.outgoing.Delete(client.MuxID)
   366  	}()
   367  	return client.traceRoundtrip(ctx, c.trace, h, req)
   368  }
   369  
   370  // NewStream creates a new stream.
   371  // Initial payload can be reused by the caller.
   372  func (c *Connection) NewStream(ctx context.Context, h HandlerID, payload []byte) (st *Stream, err error) {
   373  	if !h.valid() {
   374  		return nil, ErrUnknownHandler
   375  	}
   376  	if c.State() != StateConnected {
   377  		return nil, ErrDisconnected
   378  	}
   379  	handler := c.handlers.streams[h]
   380  	if handler == nil {
   381  		return nil, ErrUnknownHandler
   382  	}
   383  
   384  	var requests chan []byte
   385  	var responses chan Response
   386  	if handler.InCapacity > 0 {
   387  		requests = make(chan []byte, handler.InCapacity)
   388  	}
   389  	if handler.OutCapacity > 0 {
   390  		responses = make(chan Response, handler.OutCapacity)
   391  	} else {
   392  		responses = make(chan Response, 1)
   393  	}
   394  
   395  	cl, err := c.newMuxClient(ctx)
   396  	if err != nil {
   397  		return nil, err
   398  	}
   399  
   400  	return cl.RequestStream(h, payload, requests, responses)
   401  }
   402  
   403  // NewStream creates a new stream.
   404  // Initial payload can be reused by the caller.
   405  func (c *Subroute) NewStream(ctx context.Context, h HandlerID, payload []byte) (st *Stream, err error) {
   406  	if !h.valid() {
   407  		return nil, ErrUnknownHandler
   408  	}
   409  	if c.State() != StateConnected {
   410  		return nil, ErrDisconnected
   411  	}
   412  	handler := c.handlers.subStreams[makeZeroSubHandlerID(h)]
   413  	if handler == nil {
   414  		if debugPrint {
   415  			fmt.Println("want", makeZeroSubHandlerID(h), c.route, "got", c.handlers.subStreams)
   416  		}
   417  		return nil, ErrUnknownHandler
   418  	}
   419  
   420  	var requests chan []byte
   421  	var responses chan Response
   422  	if handler.InCapacity > 0 {
   423  		requests = make(chan []byte, handler.InCapacity)
   424  	}
   425  	if handler.OutCapacity > 0 {
   426  		responses = make(chan Response, handler.OutCapacity)
   427  	} else {
   428  		responses = make(chan Response, 1)
   429  	}
   430  
   431  	cl, err := c.newMuxClient(ctx)
   432  	if err != nil {
   433  		return nil, err
   434  	}
   435  	cl.subroute = &c.subID
   436  
   437  	return cl.RequestStream(h, payload, requests, responses)
   438  }
   439  
   440  // WaitForConnect will block until a connection has been established or
   441  // the context is canceled, in which case the context error is returned.
   442  func (c *Connection) WaitForConnect(ctx context.Context) error {
   443  	if debugPrint {
   444  		fmt.Println(c.Local, "->", c.Remote, "WaitForConnect")
   445  		defer fmt.Println(c.Local, "->", c.Remote, "WaitForConnect done")
   446  	}
   447  	c.connChange.L.Lock()
   448  	if atomic.LoadUint32((*uint32)(&c.state)) == StateConnected {
   449  		c.connChange.L.Unlock()
   450  		// Happy path.
   451  		return nil
   452  	}
   453  	ctx, cancel := context.WithCancel(ctx)
   454  	defer cancel()
   455  	changed := make(chan State, 1)
   456  	go func() {
   457  		defer xioutil.SafeClose(changed)
   458  		for {
   459  			c.connChange.Wait()
   460  			newState := c.State()
   461  			select {
   462  			case changed <- newState:
   463  				if newState == StateConnected || newState == StateShutdown {
   464  					c.connChange.L.Unlock()
   465  					return
   466  				}
   467  			case <-ctx.Done():
   468  				c.connChange.L.Unlock()
   469  				return
   470  			}
   471  		}
   472  	}()
   473  
   474  	for {
   475  		select {
   476  		case <-ctx.Done():
   477  			return context.Cause(ctx)
   478  		case newState := <-changed:
   479  			if newState == StateConnected {
   480  				return nil
   481  			}
   482  		}
   483  	}
   484  }
   485  
   486  /*
   487  var ErrDone = errors.New("done for now")
   488  
   489  var ErrRemoteRestart = errors.New("remote restarted")
   490  
   491  
   492  // Stateless connects to the remote handler and return all packets sent back.
   493  // If the remote is restarted will return ErrRemoteRestart.
   494  // If nil will be returned remote call sent EOF or ErrDone is returned by the callback.
   495  // If ErrDone is returned on cb nil will be returned.
   496  func (c *Connection) Stateless(ctx context.Context, h HandlerID, req []byte, cb func([]byte) error) error {
   497  	client, err := c.newMuxClient(ctx)
   498  	if err != nil {
   499  		return err
   500  	}
   501  	defer c.outgoing.Delete(client.MuxID)
   502  	resp := make(chan Response, 10)
   503  	client.RequestStateless(h, req, resp)
   504  
   505  	for r := range resp {
   506  		if r.Err != nil {
   507  			return r.Err
   508  		}
   509  		if len(r.Msg) > 0 {
   510  			err := cb(r.Msg)
   511  			if err != nil {
   512  				if errors.Is(err, ErrDone) {
   513  					break
   514  				}
   515  				return err
   516  			}
   517  		}
   518  	}
   519  	return nil
   520  }
   521  */
   522  
   523  // shouldConnect returns a deterministic bool whether the local should initiate the connection.
   524  // It should be 50% chance of any host initiating the connection.
   525  func (c *Connection) shouldConnect() bool {
   526  	// The remote should have the opposite result.
   527  	h0 := xxh3.HashString(c.Local + c.Remote)
   528  	h1 := xxh3.HashString(c.Remote + c.Local)
   529  	if h0 == h1 {
   530  		return c.Local < c.Remote
   531  	}
   532  	return h0 < h1
   533  }
   534  
   535  func (c *Connection) send(ctx context.Context, msg []byte) error {
   536  	select {
   537  	case <-ctx.Done():
   538  		// Returning error here is too noisy.
   539  		return nil
   540  	case c.outQueue <- msg:
   541  		return nil
   542  	}
   543  }
   544  
   545  // queueMsg queues a message, with an optional payload.
   546  // sender should not reference msg.Payload
   547  func (c *Connection) queueMsg(msg message, payload sender) error {
   548  	// Add baseflags.
   549  	msg.Flags.Set(c.baseFlags)
   550  	// This cannot encode subroute.
   551  	msg.Flags.Clear(FlagSubroute)
   552  	if payload != nil {
   553  		if cap(msg.Payload) < payload.Msgsize() {
   554  			old := msg.Payload
   555  			msg.Payload = GetByteBuffer()[:0]
   556  			PutByteBuffer(old)
   557  		}
   558  		var err error
   559  		msg.Payload, err = payload.MarshalMsg(msg.Payload[:0])
   560  		msg.Op = payload.Op()
   561  		if err != nil {
   562  			return err
   563  		}
   564  	}
   565  	defer PutByteBuffer(msg.Payload)
   566  	dst := GetByteBuffer()[:0]
   567  	dst, err := msg.MarshalMsg(dst)
   568  	if err != nil {
   569  		return err
   570  	}
   571  	if msg.Flags&FlagCRCxxh3 != 0 {
   572  		h := xxh3.Hash(dst)
   573  		dst = binary.LittleEndian.AppendUint32(dst, uint32(h))
   574  	}
   575  	return c.send(c.ctx, dst)
   576  }
   577  
   578  // sendMsg will send
   579  func (c *Connection) sendMsg(conn net.Conn, msg message, payload msgp.MarshalSizer) error {
   580  	if payload != nil {
   581  		if cap(msg.Payload) < payload.Msgsize() {
   582  			PutByteBuffer(msg.Payload)
   583  			msg.Payload = GetByteBuffer()[:0]
   584  		}
   585  		var err error
   586  		msg.Payload, err = payload.MarshalMsg(msg.Payload)
   587  		if err != nil {
   588  			return err
   589  		}
   590  		defer PutByteBuffer(msg.Payload)
   591  	}
   592  	dst := GetByteBuffer()[:0]
   593  	dst, err := msg.MarshalMsg(dst)
   594  	if err != nil {
   595  		return err
   596  	}
   597  	if msg.Flags&FlagCRCxxh3 != 0 {
   598  		h := xxh3.Hash(dst)
   599  		dst = binary.LittleEndian.AppendUint32(dst, uint32(h))
   600  	}
   601  	if debugPrint {
   602  		fmt.Println(c.Local, "sendMsg: Sending", msg.Op, "as", len(dst), "bytes")
   603  	}
   604  	if c.outgoingBytes != nil {
   605  		c.outgoingBytes(int64(len(dst)))
   606  	}
   607  	err = conn.SetWriteDeadline(time.Now().Add(connWriteTimeout))
   608  	if err != nil {
   609  		return err
   610  	}
   611  	return wsutil.WriteMessage(conn, c.side, ws.OpBinary, dst)
   612  }
   613  
   614  func (c *Connection) connect() {
   615  	c.updateState(StateConnecting)
   616  	rng := rand.New(rand.NewSource(time.Now().UnixNano()))
   617  	// Runs until the server is shut down.
   618  	for {
   619  		if c.State() == StateShutdown {
   620  			return
   621  		}
   622  		toDial := strings.Replace(c.Remote, "http://", "ws://", 1)
   623  		toDial = strings.Replace(toDial, "https://", "wss://", 1)
   624  		toDial += RoutePath
   625  
   626  		dialer := ws.DefaultDialer
   627  		dialer.ReadBufferSize = readBufferSize
   628  		dialer.WriteBufferSize = writeBufferSize
   629  		dialer.Timeout = defaultDialTimeout
   630  		if c.dialer != nil {
   631  			dialer.NetDial = c.dialer.DialContext
   632  		}
   633  		if c.header == nil {
   634  			c.header = make(http.Header, 2)
   635  		}
   636  		c.header.Set("Authorization", "Bearer "+c.auth(""))
   637  		c.header.Set("X-Minio-Time", time.Now().UTC().Format(time.RFC3339))
   638  
   639  		if len(c.header) > 0 {
   640  			dialer.Header = ws.HandshakeHeaderHTTP(c.header)
   641  		}
   642  		dialer.TLSConfig = c.tlsConfig
   643  		dialStarted := time.Now()
   644  		if debugPrint {
   645  			fmt.Println(c.Local, "Connecting to ", toDial)
   646  		}
   647  		conn, br, _, err := dialer.Dial(c.ctx, toDial)
   648  		if br != nil {
   649  			ws.PutReader(br)
   650  		}
   651  		c.connMu.Lock()
   652  		c.debugOutConn = conn
   653  		c.connMu.Unlock()
   654  		retry := func(err error) {
   655  			if debugPrint {
   656  				fmt.Printf("%v Connecting to %v: %v. Retrying.\n", c.Local, toDial, err)
   657  			}
   658  			sleep := defaultDialTimeout + time.Duration(rng.Int63n(int64(defaultDialTimeout)))
   659  			next := dialStarted.Add(sleep / 2)
   660  			sleep = time.Until(next).Round(time.Millisecond)
   661  			if sleep < 0 {
   662  				sleep = 0
   663  			}
   664  			gotState := c.State()
   665  			if gotState == StateShutdown {
   666  				return
   667  			}
   668  			if gotState != StateConnecting {
   669  				// Don't print error on first attempt,
   670  				// and after that only once per hour.
   671  				logger.LogOnceIf(c.ctx, fmt.Errorf("grid: %s connecting to %s: %w (%T) Sleeping %v (%v)", c.Local, toDial, err, err, sleep, gotState), toDial)
   672  			}
   673  			c.updateState(StateConnectionError)
   674  			time.Sleep(sleep)
   675  		}
   676  		if err != nil {
   677  			retry(err)
   678  			continue
   679  		}
   680  		// Send connect message.
   681  		m := message{
   682  			Op: OpConnect,
   683  		}
   684  		req := connectReq{
   685  			Host: c.Local,
   686  			ID:   c.id,
   687  		}
   688  		err = c.sendMsg(conn, m, &req)
   689  		if err != nil {
   690  			retry(err)
   691  			continue
   692  		}
   693  		// Wait for response
   694  		var r connectResp
   695  		err = c.receive(conn, &r)
   696  		if err != nil {
   697  			if debugPrint {
   698  				fmt.Println(c.Local, "receive err:", err, "side:", c.side)
   699  			}
   700  			retry(err)
   701  			continue
   702  		}
   703  		if debugPrint {
   704  			fmt.Println(c.Local, "Got connectResp:", r)
   705  		}
   706  		if !r.Accepted {
   707  			retry(fmt.Errorf("connection rejected: %s", r.RejectedReason))
   708  			continue
   709  		}
   710  		c.reconnectMu.Lock()
   711  		remoteUUID := uuid.UUID(r.ID)
   712  		if c.remoteID != nil {
   713  			c.reconnected()
   714  		}
   715  		c.remoteID = &remoteUUID
   716  		if debugPrint {
   717  			fmt.Println(c.Local, "Connected Waiting for Messages")
   718  		}
   719  		// Handle messages...
   720  		c.handleMessages(c.ctx, conn)
   721  		// Reconnect unless we are shutting down (debug only).
   722  		if c.State() == StateShutdown {
   723  			conn.Close()
   724  			return
   725  		}
   726  		if debugPrint {
   727  			fmt.Println(c.Local, "Disconnected. Attempting to reconnect.")
   728  		}
   729  	}
   730  }
   731  
   732  func (c *Connection) disconnected() {
   733  	c.outgoing.Range(func(key uint64, client *muxClient) bool {
   734  		if !client.stateless {
   735  			client.cancelFn(ErrDisconnected)
   736  		}
   737  		return true
   738  	})
   739  	if debugReqs {
   740  		fmt.Println(c.String(), "Disconnected. Clearing outgoing.")
   741  	}
   742  	c.outgoing.Clear()
   743  	c.inStream.Range(func(key uint64, client *muxServer) bool {
   744  		client.cancel()
   745  		return true
   746  	})
   747  	c.inStream.Clear()
   748  }
   749  
   750  func (c *Connection) receive(conn net.Conn, r receiver) error {
   751  	b, op, err := wsutil.ReadData(conn, c.side)
   752  	if err != nil {
   753  		return err
   754  	}
   755  	if op != ws.OpBinary {
   756  		return fmt.Errorf("unexpected connect response type %v", op)
   757  	}
   758  	if c.incomingBytes != nil {
   759  		c.incomingBytes(int64(len(b)))
   760  	}
   761  
   762  	var m message
   763  	_, _, err = m.parse(b)
   764  	if err != nil {
   765  		return err
   766  	}
   767  	if m.Op != r.Op() {
   768  		return fmt.Errorf("unexpected response OP, want %v, got %v", r.Op(), m.Op)
   769  	}
   770  	_, err = r.UnmarshalMsg(m.Payload)
   771  	return err
   772  }
   773  
   774  func (c *Connection) handleIncoming(ctx context.Context, conn net.Conn, req connectReq) error {
   775  	c.connMu.Lock()
   776  	c.debugInConn = conn
   777  	c.connMu.Unlock()
   778  	if c.blockConnect != nil {
   779  		// Block until we are allowed to connect.
   780  		<-c.blockConnect
   781  	}
   782  	if req.Host != c.Remote {
   783  		err := fmt.Errorf("expected remote '%s', got '%s'", c.Remote, req.Host)
   784  		if debugPrint {
   785  			fmt.Println(err)
   786  		}
   787  		return err
   788  	}
   789  	if c.shouldConnect() {
   790  		if debugPrint {
   791  			fmt.Println("expected to be client side, not server side")
   792  		}
   793  		return errors.New("grid: expected to be client side, not server side")
   794  	}
   795  	msg := message{
   796  		Op: OpConnectResponse,
   797  	}
   798  
   799  	resp := connectResp{
   800  		ID:       c.id,
   801  		Accepted: true,
   802  	}
   803  	err := c.sendMsg(conn, msg, &resp)
   804  	if debugPrint {
   805  		fmt.Printf("grid: Queued Response %+v Side: %v\n", resp, c.side)
   806  	}
   807  	if err != nil {
   808  		return err
   809  	}
   810  	// Signal that we are reconnected, update state and handle messages.
   811  	// Prevent other connections from connecting while we process.
   812  	c.reconnectMu.Lock()
   813  	if c.remoteID != nil {
   814  		c.reconnected()
   815  	}
   816  	rid := uuid.UUID(req.ID)
   817  	c.remoteID = &rid
   818  
   819  	// Handle incoming messages until disconnect.
   820  	c.handleMessages(ctx, conn)
   821  	return nil
   822  }
   823  
   824  // reconnected signals the connection has been reconnected.
   825  // It will close all active requests and streams.
   826  // caller *must* hold reconnectMu.
   827  func (c *Connection) reconnected() {
   828  	c.updateState(StateConnectionError)
   829  	// Close all active requests.
   830  	if debugReqs {
   831  		fmt.Println(c.String(), "Reconnected. Clearing outgoing.")
   832  	}
   833  	c.outgoing.Range(func(key uint64, client *muxClient) bool {
   834  		client.close()
   835  		return true
   836  	})
   837  	c.inStream.Range(func(key uint64, value *muxServer) bool {
   838  		value.close()
   839  		return true
   840  	})
   841  
   842  	c.inStream.Clear()
   843  	c.outgoing.Clear()
   844  
   845  	// Wait for existing to exit
   846  	c.handleMsgWg.Wait()
   847  }
   848  
   849  func (c *Connection) updateState(s State) {
   850  	c.connChange.L.Lock()
   851  	defer c.connChange.L.Unlock()
   852  
   853  	// We may have reads that aren't locked, so update atomically.
   854  	gotState := atomic.LoadUint32((*uint32)(&c.state))
   855  	if gotState == StateShutdown || State(gotState) == s {
   856  		return
   857  	}
   858  	if s == StateConnected {
   859  		atomic.StoreInt64(&c.LastPong, time.Now().UnixNano())
   860  	}
   861  	atomic.StoreUint32((*uint32)(&c.state), uint32(s))
   862  	if debugPrint {
   863  		fmt.Println(c.Local, "updateState:", gotState, "->", s)
   864  	}
   865  	c.connChange.Broadcast()
   866  }
   867  
   868  // monitorState will monitor the state of the connection and close the net.Conn if it changes.
   869  func (c *Connection) monitorState(conn net.Conn, cancel context.CancelCauseFunc) {
   870  	c.connChange.L.Lock()
   871  	defer c.connChange.L.Unlock()
   872  	for {
   873  		newState := c.State()
   874  		if newState != StateConnected {
   875  			conn.Close()
   876  			cancel(ErrDisconnected)
   877  			return
   878  		}
   879  		// Unlock and wait for state change.
   880  		c.connChange.Wait()
   881  	}
   882  }
   883  
// handleMessages will handle incoming messages on conn.
// caller *must* hold reconnectMu.
//
// The function marks the connection as connected, starts a state monitor and
// a read goroutine, releases reconnectMu and then runs the write loop on the
// calling goroutine. It returns when the derived context is canceled, the
// connection leaves StateConnected, or a write fails.
func (c *Connection) handleMessages(ctx context.Context, conn net.Conn) {
	c.updateState(StateConnected)
	ctx, cancel := context.WithCancelCause(ctx)
	defer cancel(ErrDisconnected)

	// This will ensure that if something asks to disconnect and we are blocked on reads/writes
	// the connection will be closed and readers/writers will unblock.
	go c.monitorState(conn, cancel)

	// One count for the read goroutine and one for the write loop below;
	// each calls Done in its deferred cleanup.
	c.handleMsgWg.Add(2)
	c.reconnectMu.Unlock()

	// Read goroutine
	go func() {
		defer func() {
			if rec := recover(); rec != nil {
				logger.LogIf(ctx, fmt.Errorf("handleMessages: panic recovered: %v", rec))
				debug.PrintStack()
			}
			// Only flip Connected -> ConnectionError; any other state is left as is.
			c.connChange.L.Lock()
			if atomic.CompareAndSwapUint32((*uint32)(&c.state), StateConnected, StateConnectionError) {
				c.connChange.Broadcast()
			}
			c.connChange.L.Unlock()
			conn.Close()
			c.handleMsgWg.Done()
		}()

		controlHandler := wsutil.ControlFrameHandler(conn, c.side)
		wsReader := wsutil.Reader{
			Source:          conn,
			State:           c.side,
			CheckUTF8:       true,
			SkipHeaderCheck: false,
			OnIntermediate:  controlHandler,
		}
		// readDataInto reads the next frame matching 'want' into dst, reusing dst's
		// capacity when it is large enough. Control frames are passed to controlHandler
		// and frames with other opcodes are discarded.
		// The rw and s parameters are not used by the body; frames are read through wsReader.
		readDataInto := func(dst []byte, rw io.ReadWriter, s ws.State, want ws.OpCode) ([]byte, error) {
			dst = dst[:0]
			for {
				hdr, err := wsReader.NextFrame()
				if err != nil {
					return nil, err
				}
				if hdr.OpCode.IsControl() {
					if err := controlHandler(hdr, &wsReader); err != nil {
						return nil, err
					}
					continue
				}
				if hdr.OpCode&want == 0 {
					if err := wsReader.Discard(); err != nil {
						return nil, err
					}
					continue
				}
				if int64(cap(dst)) < hdr.Length+1 {
					// Allocate the frame length plus 1/8th extra capacity.
					dst = make([]byte, 0, hdr.Length+hdr.Length>>3)
				}
				return readAllInto(dst[:0], &wsReader)
			}
		}

		// Keep reusing the same buffer.
		var msg []byte
		for {
			if atomic.LoadUint32((*uint32)(&c.state)) != StateConnected {
				cancel(ErrDisconnected)
				return
			}
			if cap(msg) > readBufferSize*4 {
				// Don't keep too much memory around.
				msg = nil
			}

			var err error
			msg, err = readDataInto(msg, conn, c.side, ws.OpBinary)
			if err != nil {
				cancel(ErrDisconnected)
				logger.LogIfNot(ctx, fmt.Errorf("ws read: %w", err), net.ErrClosed, io.EOF)
				return
			}
			// Report the number of received bytes, if a callback is set.
			if c.incomingBytes != nil {
				c.incomingBytes(int64(len(msg)))
			}

			// Parse the received message
			var m message
			subID, remain, err := m.parse(msg)
			if err != nil {
				logger.LogIf(ctx, fmt.Errorf("ws parse package: %w", err))
				cancel(ErrDisconnected)
				return
			}
			if debugPrint {
				fmt.Printf("%s Got msg: %v\n", c.Local, m)
			}
			if m.Op != OpMerged {
				c.handleMsg(ctx, m, subID)
				continue
			}
			// Handle merged messages.
			// Seq carries the number of merged messages; each one follows
			// as a msgp byte slice in the remaining data.
			messages := int(m.Seq)
			for i := 0; i < messages; i++ {
				if atomic.LoadUint32((*uint32)(&c.state)) != StateConnected {
					cancel(ErrDisconnected)
					return
				}
				var next []byte
				next, remain, err = msgp.ReadBytesZC(remain)
				if err != nil {
					logger.LogIf(ctx, fmt.Errorf("ws read merged: %w", err))
					cancel(ErrDisconnected)
					return
				}

				m.Payload = nil
				subID, _, err = m.parse(next)
				if err != nil {
					logger.LogIf(ctx, fmt.Errorf("ws parse merged: %w", err))
					cancel(ErrDisconnected)
					return
				}
				c.handleMsg(ctx, m, subID)
			}
		}
	}()

	// Write function.
	// Cleanup for the write loop: cancel the context, flip Connected ->
	// ConnectionError, call c.disconnected(), close conn and release the wait group.
	defer func() {
		if rec := recover(); rec != nil {
			logger.LogIf(ctx, fmt.Errorf("handleMessages: panic recovered: %v", rec))
			debug.PrintStack()
		}
		if debugPrint {
			fmt.Println("handleMessages: write goroutine exited")
		}
		cancel(ErrDisconnected)
		c.connChange.L.Lock()
		if atomic.CompareAndSwapUint32((*uint32)(&c.state), StateConnected, StateConnectionError) {
			c.connChange.Broadcast()
		}
		c.disconnected()
		c.connChange.L.Unlock()

		conn.Close()
		c.handleMsgWg.Done()
	}()

	// Read the ping interval under connMu; debugMsg may modify it.
	c.connMu.Lock()
	connPingInterval := c.connPingInterval
	c.connMu.Unlock()
	ping := time.NewTicker(connPingInterval)
	pingFrame := message{
		Op:         OpPing,
		DeadlineMS: 5000,
	}

	defer ping.Stop()
	// queue holds messages waiting to be merged into a single write,
	// queueSize is the sum of their lengths.
	queue := make([][]byte, 0, maxMergeMessages)
	merged := make([]byte, 0, writeBufferSize)
	var queueSize int
	var buf bytes.Buffer
	var wsw wsWriter
	for {
		var toSend []byte
		select {
		case <-ctx.Done():
			return
		case <-ping.C:
			if c.State() != StateConnected {
				continue
			}
			// LastPong is stored as Unix seconds; zero means no pong has been recorded.
			lastPong := atomic.LoadInt64(&c.LastPong)
			if lastPong > 0 {
				lastPongTime := time.Unix(lastPong, 0)
				if d := time.Since(lastPongTime); d > connPingInterval*2 {
					logger.LogIf(ctx, fmt.Errorf("host %s last pong too old (%v); disconnecting", c.Remote, d.Round(time.Millisecond)))
					return
				}
			}
			var err error
			toSend, err = pingFrame.MarshalMsg(GetByteBuffer()[:0])
			if err != nil {
				logger.LogIf(ctx, err)
				// Fake it...
				atomic.StoreInt64(&c.LastPong, time.Now().Unix())
				continue
			}
		case toSend = <-c.outQueue:
			if len(toSend) == 0 {
				continue
			}
		}
		// Defer sending while more messages are pending and the merge limits
		// (message count and total size) have not been reached.
		if len(queue) < maxMergeMessages && queueSize+len(toSend) < writeBufferSize-1024 && len(c.outQueue) > 0 {
			queue = append(queue, toSend)
			queueSize += len(toSend)
			continue
		}
		// Wait until the connection is in StateConnected before writing.
		// Shutdown, connection error or context cancellation end the loop.
		c.connChange.L.Lock()
		for {
			state := c.State()
			if state == StateConnected {
				break
			}
			if debugPrint {
				fmt.Println(c.Local, "Waiting for connection ->", c.Remote, "state: ", state)
			}
			if state == StateShutdown || state == StateConnectionError {
				c.connChange.L.Unlock()
				return
			}
			c.connChange.Wait()
			select {
			case <-ctx.Done():
				c.connChange.L.Unlock()
				return
			default:
			}
		}
		c.connChange.L.Unlock()
		if len(queue) == 0 {
			// Nothing queued: send this message as its own frame.
			// Combine writes.
			buf.Reset()
			err := wsw.writeMessage(&buf, c.side, ws.OpBinary, toSend)
			if err != nil {
				logger.LogIf(ctx, fmt.Errorf("ws writeMessage: %w", err))
				return
			}
			// The frame has been copied into buf, so toSend can be recycled.
			PutByteBuffer(toSend)
			err = conn.SetWriteDeadline(time.Now().Add(connWriteTimeout))
			if err != nil {
				logger.LogIf(ctx, fmt.Errorf("conn.SetWriteDeadline: %w", err))
				return
			}
			_, err = buf.WriteTo(conn)
			if err != nil {
				logger.LogIf(ctx, fmt.Errorf("ws write: %w", err))
				return
			}
			continue
		}

		// Merge entries and send
		queue = append(queue, toSend)
		if debugPrint {
			fmt.Println("Merging", len(queue), "messages")
		}

		// Build the merged message in the reusable 'merged' buffer:
		// an OpMerged header with the count in Seq, followed by each entry.
		toSend = merged[:0]
		m := message{Op: OpMerged, Seq: uint32(len(queue))}
		var err error
		toSend, err = m.MarshalMsg(toSend)
		if err != nil {
			logger.LogIf(ctx, fmt.Errorf("msg.MarshalMsg: %w", err))
			return
		}
		// Append as byte slices.
		for _, q := range queue {
			toSend = msgp.AppendBytes(toSend, q)
			PutByteBuffer(q)
		}
		queue = queue[:0]
		queueSize = 0

		// Combine writes.
		// Consider avoiding buffer copy.
		buf.Reset()
		err = wsw.writeMessage(&buf, c.side, ws.OpBinary, toSend)
		if err != nil {
			logger.LogIf(ctx, fmt.Errorf("ws writeMessage: %w", err))
			return
		}
		// buf is our local buffer, so we can reuse it.
		err = conn.SetWriteDeadline(time.Now().Add(connWriteTimeout))
		if err != nil {
			logger.LogIf(ctx, fmt.Errorf("conn.SetWriteDeadline: %w", err))
			return
		}
		_, err = buf.WriteTo(conn)
		if err != nil {
			logger.LogIf(ctx, fmt.Errorf("ws write: %w", err))
			return
		}

		if buf.Cap() > writeBufferSize*4 {
			// Reset buffer if it gets too big, so we don't keep it around.
			buf = bytes.Buffer{}
		}
	}
}
  1176  
  1177  func (c *Connection) handleMsg(ctx context.Context, m message, subID *subHandlerID) {
  1178  	switch m.Op {
  1179  	case OpMuxServerMsg:
  1180  		c.handleMuxServerMsg(ctx, m)
  1181  	case OpResponse:
  1182  		c.handleResponse(m)
  1183  	case OpMuxClientMsg:
  1184  		c.handleMuxClientMsg(ctx, m)
  1185  	case OpUnblockSrvMux:
  1186  		c.handleUnblockSrvMux(m)
  1187  	case OpUnblockClMux:
  1188  		c.handleUnblockClMux(m)
  1189  	case OpDisconnectServerMux:
  1190  		c.handleDisconnectServerMux(m)
  1191  	case OpDisconnectClientMux:
  1192  		c.handleDisconnectClientMux(m)
  1193  	case OpPing:
  1194  		c.handlePing(ctx, m)
  1195  	case OpPong:
  1196  		c.handlePong(ctx, m)
  1197  	case OpRequest:
  1198  		c.handleRequest(ctx, m, subID)
  1199  	case OpAckMux:
  1200  		c.handleAckMux(ctx, m)
  1201  	case OpConnectMux:
  1202  		c.handleConnectMux(ctx, m, subID)
  1203  	case OpMuxConnectError:
  1204  		c.handleConnectMuxError(ctx, m)
  1205  	default:
  1206  		logger.LogIf(ctx, fmt.Errorf("unknown message type: %v", m.Op))
  1207  	}
  1208  }
  1209  
  1210  func (c *Connection) handleConnectMux(ctx context.Context, m message, subID *subHandlerID) {
  1211  	// Stateless stream:
  1212  	if m.Flags&FlagStateless != 0 {
  1213  		// Reject for now, so we can safely add it later.
  1214  		if true {
  1215  			logger.LogIf(ctx, c.queueMsg(m, muxConnectError{Error: "Stateless streams not supported"}))
  1216  			return
  1217  		}
  1218  
  1219  		var handler *StatelessHandler
  1220  		if subID == nil {
  1221  			handler = c.handlers.stateless[m.Handler]
  1222  		} else {
  1223  			handler = c.handlers.subStateless[*subID]
  1224  		}
  1225  		if handler == nil {
  1226  			logger.LogIf(ctx, c.queueMsg(m, muxConnectError{Error: "Invalid Handler for type"}))
  1227  			return
  1228  		}
  1229  		_, _ = c.inStream.LoadOrCompute(m.MuxID, func() *muxServer {
  1230  			return newMuxStateless(ctx, m, c, *handler)
  1231  		})
  1232  	} else {
  1233  		// Stream:
  1234  		var handler *StreamHandler
  1235  		if subID == nil {
  1236  			if !m.Handler.valid() {
  1237  				logger.LogIf(ctx, c.queueMsg(m, muxConnectError{Error: "Invalid Handler"}))
  1238  				return
  1239  			}
  1240  			handler = c.handlers.streams[m.Handler]
  1241  		} else {
  1242  			handler = c.handlers.subStreams[*subID]
  1243  		}
  1244  		if handler == nil {
  1245  			logger.LogIf(ctx, c.queueMsg(m, muxConnectError{Error: "Invalid Handler for type"}))
  1246  			return
  1247  		}
  1248  
  1249  		// Start a new server handler if none exists.
  1250  		_, _ = c.inStream.LoadOrCompute(m.MuxID, func() *muxServer {
  1251  			return newMuxStream(ctx, m, c, *handler)
  1252  		})
  1253  	}
  1254  }
  1255  
  1256  // handleConnectMuxError when mux connect was rejected.
  1257  func (c *Connection) handleConnectMuxError(ctx context.Context, m message) {
  1258  	if v, ok := c.outgoing.Load(m.MuxID); ok {
  1259  		var cErr muxConnectError
  1260  		_, err := cErr.UnmarshalMsg(m.Payload)
  1261  		logger.LogIf(ctx, err)
  1262  		v.error(RemoteErr(cErr.Error))
  1263  		return
  1264  	}
  1265  	PutByteBuffer(m.Payload)
  1266  }
  1267  
  1268  func (c *Connection) handleAckMux(ctx context.Context, m message) {
  1269  	PutByteBuffer(m.Payload)
  1270  	v, ok := c.outgoing.Load(m.MuxID)
  1271  	if !ok {
  1272  		if m.Flags&FlagEOF == 0 {
  1273  			logger.LogIf(ctx, c.queueMsg(message{Op: OpDisconnectClientMux, MuxID: m.MuxID}, nil))
  1274  		}
  1275  		return
  1276  	}
  1277  	if debugPrint {
  1278  		fmt.Println(c.Local, "Mux", m.MuxID, "Acknowledged")
  1279  	}
  1280  	v.ack(m.Seq)
  1281  }
  1282  
  1283  func (c *Connection) handleRequest(ctx context.Context, m message, subID *subHandlerID) {
  1284  	if !m.Handler.valid() {
  1285  		logger.LogIf(ctx, c.queueMsg(m, muxConnectError{Error: "Invalid Handler"}))
  1286  		return
  1287  	}
  1288  	if debugReqs {
  1289  		fmt.Println(m.MuxID, c.StringReverse(), "INCOMING")
  1290  	}
  1291  	// Singleshot message
  1292  	var handler SingleHandlerFn
  1293  	if subID == nil {
  1294  		handler = c.handlers.single[m.Handler]
  1295  	} else {
  1296  		handler = c.handlers.subSingle[*subID]
  1297  	}
  1298  	if handler == nil {
  1299  		logger.LogIf(ctx, c.queueMsg(m, muxConnectError{Error: "Invalid Handler for type"}))
  1300  		return
  1301  	}
  1302  
  1303  	// TODO: This causes allocations, but escape analysis doesn't really show the cause.
  1304  	// If another faithful engineer wants to take a stab, feel free.
  1305  	go func(m message) {
  1306  		var start time.Time
  1307  		if m.DeadlineMS > 0 {
  1308  			start = time.Now()
  1309  		}
  1310  		var b []byte
  1311  		var err *RemoteErr
  1312  		func() {
  1313  			defer func() {
  1314  				if rec := recover(); rec != nil {
  1315  					err = NewRemoteErrString(fmt.Sprintf("handleMessages: panic recovered: %v", rec))
  1316  					debug.PrintStack()
  1317  					logger.LogIf(ctx, err)
  1318  				}
  1319  			}()
  1320  			b, err = handler(m.Payload)
  1321  			if debugPrint {
  1322  				fmt.Println(c.Local, "Handler returned payload:", bytesOrLength(b), "err:", err)
  1323  			}
  1324  		}()
  1325  
  1326  		if m.DeadlineMS > 0 && time.Since(start).Milliseconds()+c.addDeadline.Milliseconds() > int64(m.DeadlineMS) {
  1327  			if debugReqs {
  1328  				fmt.Println(m.MuxID, c.StringReverse(), "DEADLINE EXCEEDED")
  1329  			}
  1330  			// No need to return result
  1331  			PutByteBuffer(b)
  1332  			return
  1333  		}
  1334  		if debugReqs {
  1335  			fmt.Println(m.MuxID, c.StringReverse(), "RESPONDING")
  1336  		}
  1337  		m = message{
  1338  			MuxID: m.MuxID,
  1339  			Seq:   m.Seq,
  1340  			Op:    OpResponse,
  1341  			Flags: FlagEOF,
  1342  		}
  1343  		if err != nil {
  1344  			m.Flags |= FlagPayloadIsErr
  1345  			m.Payload = []byte(*err)
  1346  		} else {
  1347  			m.Payload = b
  1348  			m.setZeroPayloadFlag()
  1349  		}
  1350  		logger.LogIf(ctx, c.queueMsg(m, nil))
  1351  	}(m)
  1352  }
  1353  
  1354  func (c *Connection) handlePong(ctx context.Context, m message) {
  1355  	var pong pongMsg
  1356  	_, err := pong.UnmarshalMsg(m.Payload)
  1357  	PutByteBuffer(m.Payload)
  1358  	logger.LogIf(ctx, err)
  1359  	if m.MuxID == 0 {
  1360  		atomic.StoreInt64(&c.LastPong, time.Now().Unix())
  1361  		return
  1362  	}
  1363  	if v, ok := c.outgoing.Load(m.MuxID); ok {
  1364  		v.pong(pong)
  1365  	} else {
  1366  		// We don't care if the client was removed in the meantime,
  1367  		// but we send a disconnect message to the server just in case.
  1368  		logger.LogIf(ctx, c.queueMsg(message{Op: OpDisconnectClientMux, MuxID: m.MuxID}, nil))
  1369  	}
  1370  }
  1371  
  1372  func (c *Connection) handlePing(ctx context.Context, m message) {
  1373  	if m.MuxID == 0 {
  1374  		logger.LogIf(ctx, c.queueMsg(m, &pongMsg{}))
  1375  		return
  1376  	}
  1377  	// Single calls do not support pinging.
  1378  	if v, ok := c.inStream.Load(m.MuxID); ok {
  1379  		pong := v.ping(m.Seq)
  1380  		logger.LogIf(ctx, c.queueMsg(m, &pong))
  1381  	} else {
  1382  		pong := pongMsg{NotFound: true}
  1383  		logger.LogIf(ctx, c.queueMsg(m, &pong))
  1384  	}
  1385  	return
  1386  }
  1387  
  1388  func (c *Connection) handleDisconnectClientMux(m message) {
  1389  	if v, ok := c.outgoing.Load(m.MuxID); ok {
  1390  		if m.Flags&FlagPayloadIsErr != 0 {
  1391  			v.error(RemoteErr(m.Payload))
  1392  		} else {
  1393  			v.error(ErrDisconnected)
  1394  		}
  1395  		return
  1396  	}
  1397  	PutByteBuffer(m.Payload)
  1398  }
  1399  
  1400  func (c *Connection) handleDisconnectServerMux(m message) {
  1401  	if debugPrint {
  1402  		fmt.Println(c.Local, "Disconnect server mux:", m.MuxID)
  1403  	}
  1404  	PutByteBuffer(m.Payload)
  1405  	m.Payload = nil
  1406  	if v, ok := c.inStream.Load(m.MuxID); ok {
  1407  		v.close()
  1408  	}
  1409  }
  1410  
  1411  func (c *Connection) handleUnblockClMux(m message) {
  1412  	PutByteBuffer(m.Payload)
  1413  	m.Payload = nil
  1414  	v, ok := c.outgoing.Load(m.MuxID)
  1415  	if !ok {
  1416  		if debugPrint {
  1417  			fmt.Println(c.Local, "Unblock: Unknown Mux:", m.MuxID)
  1418  		}
  1419  		// We can expect to receive unblocks for closed muxes
  1420  		return
  1421  	}
  1422  	v.unblockSend(m.Seq)
  1423  }
  1424  
  1425  func (c *Connection) handleUnblockSrvMux(m message) {
  1426  	if m.Payload != nil {
  1427  		PutByteBuffer(m.Payload)
  1428  	}
  1429  	m.Payload = nil
  1430  	if v, ok := c.inStream.Load(m.MuxID); ok {
  1431  		v.unblockSend(m.Seq)
  1432  		return
  1433  	}
  1434  	// We can expect to receive unblocks for closed muxes
  1435  	if debugPrint {
  1436  		fmt.Println(c.Local, "Unblock: Unknown Mux:", m.MuxID)
  1437  	}
  1438  }
  1439  
  1440  func (c *Connection) handleMuxClientMsg(ctx context.Context, m message) {
  1441  	v, ok := c.inStream.Load(m.MuxID)
  1442  	if !ok {
  1443  		if debugPrint {
  1444  			fmt.Println(c.Local, "OpMuxClientMsg: Unknown Mux:", m.MuxID)
  1445  		}
  1446  		logger.LogIf(ctx, c.queueMsg(message{Op: OpDisconnectClientMux, MuxID: m.MuxID}, nil))
  1447  		PutByteBuffer(m.Payload)
  1448  		return
  1449  	}
  1450  	v.message(m)
  1451  }
  1452  
  1453  func (c *Connection) handleResponse(m message) {
  1454  	if debugPrint {
  1455  		fmt.Printf("%s Got mux response: %v\n", c.Local, m)
  1456  	}
  1457  	v, ok := c.outgoing.Load(m.MuxID)
  1458  	if !ok {
  1459  		if debugReqs {
  1460  			fmt.Println(m.MuxID, c.String(), "Got response for unknown mux")
  1461  		}
  1462  		PutByteBuffer(m.Payload)
  1463  		return
  1464  	}
  1465  	if m.Flags&FlagPayloadIsErr != 0 {
  1466  		v.response(m.Seq, Response{
  1467  			Msg: nil,
  1468  			Err: RemoteErr(m.Payload),
  1469  		})
  1470  		PutByteBuffer(m.Payload)
  1471  	} else {
  1472  		v.response(m.Seq, Response{
  1473  			Msg: m.Payload,
  1474  			Err: nil,
  1475  		})
  1476  	}
  1477  	v.close()
  1478  	if debugReqs {
  1479  		fmt.Println(m.MuxID, c.String(), "handleResponse: closing mux")
  1480  	}
  1481  }
  1482  
  1483  func (c *Connection) handleMuxServerMsg(ctx context.Context, m message) {
  1484  	if debugPrint {
  1485  		fmt.Printf("%s Got mux msg: %v\n", c.Local, m)
  1486  	}
  1487  	v, ok := c.outgoing.Load(m.MuxID)
  1488  	if !ok {
  1489  		if m.Flags&FlagEOF == 0 {
  1490  			logger.LogIf(ctx, c.queueMsg(message{Op: OpDisconnectClientMux, MuxID: m.MuxID}, nil))
  1491  		}
  1492  		PutByteBuffer(m.Payload)
  1493  		return
  1494  	}
  1495  	if m.Flags&FlagPayloadIsErr != 0 {
  1496  		v.response(m.Seq, Response{
  1497  			Msg: nil,
  1498  			Err: RemoteErr(m.Payload),
  1499  		})
  1500  		PutByteBuffer(m.Payload)
  1501  	} else if m.Payload != nil {
  1502  		v.response(m.Seq, Response{
  1503  			Msg: m.Payload,
  1504  			Err: nil,
  1505  		})
  1506  	}
  1507  	if m.Flags&FlagEOF != 0 {
  1508  		if v.cancelFn != nil && m.Flags&FlagPayloadIsErr == 0 {
  1509  			v.cancelFn(errStreamEOF)
  1510  		}
  1511  		v.close()
  1512  		if debugReqs {
  1513  			fmt.Println(m.MuxID, c.String(), "handleMuxServerMsg: DELETING MUX")
  1514  		}
  1515  		c.outgoing.Delete(m.MuxID)
  1516  	}
  1517  }
  1518  
  1519  func (c *Connection) deleteMux(incoming bool, muxID uint64) {
  1520  	if incoming {
  1521  		if debugPrint {
  1522  			fmt.Println("deleteMux: disconnect incoming mux", muxID)
  1523  		}
  1524  		v, loaded := c.inStream.LoadAndDelete(muxID)
  1525  		if loaded && v != nil {
  1526  			logger.LogIf(c.ctx, c.queueMsg(message{Op: OpDisconnectClientMux, MuxID: muxID}, nil))
  1527  			v.close()
  1528  		}
  1529  	} else {
  1530  		if debugPrint {
  1531  			fmt.Println("deleteMux: disconnect outgoing mux", muxID)
  1532  		}
  1533  		v, loaded := c.outgoing.LoadAndDelete(muxID)
  1534  		if loaded && v != nil {
  1535  			if debugReqs {
  1536  				fmt.Println(muxID, c.String(), "deleteMux: DELETING MUX")
  1537  			}
  1538  			v.close()
  1539  			logger.LogIf(c.ctx, c.queueMsg(message{Op: OpDisconnectServerMux, MuxID: muxID}, nil))
  1540  		}
  1541  	}
  1542  }
  1543  
  1544  // State returns the current connection status.
  1545  func (c *Connection) State() State {
  1546  	return State(atomic.LoadUint32((*uint32)(&c.state)))
  1547  }
  1548  
  1549  // Stats returns the current connection stats.
  1550  func (c *Connection) Stats() ConnectionStats {
  1551  	return ConnectionStats{
  1552  		IncomingStreams: c.inStream.Size(),
  1553  		OutgoingStreams: c.outgoing.Size(),
  1554  	}
  1555  }
  1556  
  1557  func (c *Connection) debugMsg(d debugMsg, args ...any) {
  1558  	if debugPrint {
  1559  		fmt.Println("debug: sending message", d, args)
  1560  	}
  1561  
  1562  	switch d {
  1563  	case debugShutdown:
  1564  		c.updateState(StateShutdown)
  1565  	case debugKillInbound:
  1566  		c.connMu.Lock()
  1567  		defer c.connMu.Unlock()
  1568  		if c.debugInConn != nil {
  1569  			if debugPrint {
  1570  				fmt.Println("debug: closing inbound connection")
  1571  			}
  1572  			c.debugInConn.Close()
  1573  		}
  1574  	case debugKillOutbound:
  1575  		c.connMu.Lock()
  1576  		defer c.connMu.Unlock()
  1577  		if c.debugInConn != nil {
  1578  			if debugPrint {
  1579  				fmt.Println("debug: closing outgoing connection")
  1580  			}
  1581  			c.debugInConn.Close()
  1582  		}
  1583  	case debugWaitForExit:
  1584  		c.reconnectMu.Lock()
  1585  		c.handleMsgWg.Wait()
  1586  		c.reconnectMu.Unlock()
  1587  	case debugSetConnPingDuration:
  1588  		c.connMu.Lock()
  1589  		defer c.connMu.Unlock()
  1590  		c.connPingInterval = args[0].(time.Duration)
  1591  	case debugSetClientPingDuration:
  1592  		c.clientPingInterval = args[0].(time.Duration)
  1593  	case debugAddToDeadline:
  1594  		c.addDeadline = args[0].(time.Duration)
  1595  	case debugIsOutgoingClosed:
  1596  		// params: muxID uint64, isClosed func(bool)
  1597  		muxID := args[0].(uint64)
  1598  		resp := args[1].(func(b bool))
  1599  		mid, ok := c.outgoing.Load(muxID)
  1600  		if !ok || mid == nil {
  1601  			resp(true)
  1602  			return
  1603  		}
  1604  		mid.respMu.Lock()
  1605  		resp(mid.closed)
  1606  		mid.respMu.Unlock()
  1607  	}
  1608  }
  1609  
// wsWriter writes websocket messages.
// It keeps a scratch buffer for the frame header, so a single wsWriter
// should not be used by several goroutines at once.
type wsWriter struct {
	// tmp is the scratch space writeFrame encodes the frame header into.
	tmp [ws.MaxHeaderSize]byte
}
  1614  
  1615  // writeMessage writes a message to w without allocations.
  1616  func (ww *wsWriter) writeMessage(w io.Writer, s ws.State, op ws.OpCode, p []byte) error {
  1617  	const fin = true
  1618  	var frame ws.Frame
  1619  	if s.ClientSide() {
  1620  		// We do not need to copy the payload, since we own it.
  1621  		payload := p
  1622  
  1623  		frame = ws.NewFrame(op, fin, payload)
  1624  		frame = ws.MaskFrameInPlace(frame)
  1625  	} else {
  1626  		frame = ws.NewFrame(op, fin, p)
  1627  	}
  1628  
  1629  	return ww.writeFrame(w, frame)
  1630  }
  1631  
  1632  // writeFrame writes frame binary representation into w.
  1633  func (ww *wsWriter) writeFrame(w io.Writer, f ws.Frame) error {
  1634  	const (
  1635  		bit0  = 0x80
  1636  		len7  = int64(125)
  1637  		len16 = int64(^(uint16(0)))
  1638  		len64 = int64(^(uint64(0)) >> 1)
  1639  	)
  1640  
  1641  	bts := ww.tmp[:]
  1642  	if f.Header.Fin {
  1643  		bts[0] |= bit0
  1644  	}
  1645  	bts[0] |= f.Header.Rsv << 4
  1646  	bts[0] |= byte(f.Header.OpCode)
  1647  
  1648  	var n int
  1649  	switch {
  1650  	case f.Header.Length <= len7:
  1651  		bts[1] = byte(f.Header.Length)
  1652  		n = 2
  1653  
  1654  	case f.Header.Length <= len16:
  1655  		bts[1] = 126
  1656  		binary.BigEndian.PutUint16(bts[2:4], uint16(f.Header.Length))
  1657  		n = 4
  1658  
  1659  	case f.Header.Length <= len64:
  1660  		bts[1] = 127
  1661  		binary.BigEndian.PutUint64(bts[2:10], uint64(f.Header.Length))
  1662  		n = 10
  1663  
  1664  	default:
  1665  		return ws.ErrHeaderLengthUnexpected
  1666  	}
  1667  
  1668  	if f.Header.Masked {
  1669  		bts[1] |= bit0
  1670  		n += copy(bts[n:], f.Header.Mask[:])
  1671  	}
  1672  
  1673  	if _, err := w.Write(bts[:n]); err != nil {
  1674  		return err
  1675  	}
  1676  
  1677  	_, err := w.Write(f.Payload)
  1678  	return err
  1679  }