github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/mux/mux.go (about)

     1  package mux
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io"
     7  	"math"
     8  	"net"
     9  	"sync"
    10  	"sync/atomic"
    11  	"time"
    12  
    13  	"github.com/Asutorufa/yuhaiin/pkg/net/netapi"
    14  	"github.com/Asutorufa/yuhaiin/pkg/protos/node/point"
    15  	"github.com/Asutorufa/yuhaiin/pkg/protos/node/protocol"
    16  	"github.com/Asutorufa/yuhaiin/pkg/utils/relay"
    17  	"github.com/libp2p/go-yamux/v4"
    18  )
    19  
    20  var config *yamux.Config
    21  
    22  func init() {
    23  	config = yamux.DefaultConfig()
    24  	// We've bumped this to 16MiB as this critically limits throughput.
    25  	//
    26  	// 1MiB means a best case of 10MiB/s (83.89Mbps) on a connection with
    27  	// 100ms latency. The default gave us 2.4MiB *best case* which was
    28  	// totally unacceptable.
    29  	config.MaxStreamWindowSize = uint32(16 * 1024 * 1024)
    30  	// don't spam
    31  	config.LogOutput = io.Discard
    32  	// We always run over a security transport that buffers internally
    33  	// (i.e., uses a block cipher).
    34  	config.ReadBufSize = 0
    35  	// Effectively disable the incoming streams limit.
    36  	// This is now dynamically limited by the resource manager.
    37  	config.MaxIncomingStreams = math.MaxUint32
    38  	// Disable keepalive, we don't need it
    39  	// tcp keepalive will used in underlying conn
    40  	config.EnableKeepAlive = false
    41  
    42  	config.ConnectionWriteTimeout = 4*time.Second + time.Second/2
    43  
    44  	relay.AppendIgnoreError(yamux.ErrStreamReset)
    45  }
    46  
    47  type connEntry struct {
    48  	mu      sync.Mutex
    49  	session *IdleSession
    50  }
    51  
    52  func (c *connEntry) Close() error {
    53  	c.mu.Lock()
    54  	defer c.mu.Unlock()
    55  
    56  	err := c.session.Close()
    57  	c.session = nil
    58  
    59  	return err
    60  }
    61  
    62  type MuxClient struct {
    63  	netapi.Proxy
    64  	selector *rangeSelector
    65  }
    66  
    67  func init() {
    68  	point.RegisterProtocol(NewClient)
    69  }
    70  
    71  func NewClient(config *protocol.Protocol_Mux) point.WrapProxy {
    72  	return func(dialer netapi.Proxy) (netapi.Proxy, error) {
    73  		if config.Mux.Concurrency <= 0 {
    74  			config.Mux.Concurrency = 1
    75  		}
    76  
    77  		c := &MuxClient{
    78  			Proxy:    dialer,
    79  			selector: NewRangeSelector(int(config.Mux.Concurrency)),
    80  		}
    81  
    82  		return c, nil
    83  	}
    84  }
    85  
    86  func (m *MuxClient) Conn(ctx context.Context, addr netapi.Address) (net.Conn, error) {
    87  	session, err := m.nextSession(ctx)
    88  	if err != nil {
    89  		return nil, err
    90  	}
    91  
    92  	conn, err := session.OpenStream(ctx)
    93  	if err != nil {
    94  		session.closed = true
    95  		return nil, fmt.Errorf("yamux open error: %w", err)
    96  	}
    97  
    98  	return &muxConn{conn}, nil
    99  }
   100  
   101  func (m *MuxClient) nextSession(ctx context.Context) (*IdleSession, error) {
   102  	entry := m.selector.Select()
   103  
   104  	session := entry.session
   105  
   106  	if session != nil && !session.IsClosed() {
   107  		return session, nil
   108  	}
   109  
   110  	entry.mu.Lock()
   111  	defer entry.mu.Unlock()
   112  
   113  	if entry.session != nil && !entry.session.IsClosed() {
   114  		return entry.session, nil
   115  	}
   116  
   117  	dc, err := m.Proxy.Conn(ctx, netapi.EmptyAddr)
   118  	if err != nil {
   119  		return nil, err
   120  	}
   121  
   122  	yamuxSession, err := yamux.Client(dc, config, nil)
   123  	if err != nil {
   124  		dc.Close()
   125  		return nil, fmt.Errorf("yamux client error: %w", err)
   126  	}
   127  
   128  	entry.session = NewIdleSession(yamuxSession, time.Minute)
   129  
   130  	return entry.session, nil
   131  }
   132  
   133  type IdleSession struct {
   134  	closed bool
   135  	*yamux.Session
   136  
   137  	lastStreamTime *atomic.Pointer[time.Time]
   138  }
   139  
   140  func NewIdleSession(session *yamux.Session, IdleTimeout time.Duration) *IdleSession {
   141  	s := &IdleSession{
   142  		Session:        session,
   143  		lastStreamTime: &atomic.Pointer[time.Time]{},
   144  	}
   145  
   146  	s.updateLatestStreamTime()
   147  
   148  	go func() {
   149  		readyClose := false
   150  		ticker := time.NewTicker(IdleTimeout)
   151  		defer ticker.Stop()
   152  
   153  		for {
   154  			select {
   155  			case <-session.CloseChan():
   156  				return
   157  			case <-ticker.C:
   158  				if session.NumStreams() != 0 {
   159  					readyClose = false
   160  					continue
   161  				}
   162  
   163  				if time.Since(*s.lastStreamTime.Load()) < IdleTimeout {
   164  					readyClose = false
   165  					continue
   166  				}
   167  
   168  				if readyClose {
   169  					session.Close()
   170  					return
   171  				}
   172  
   173  				readyClose = true
   174  			}
   175  		}
   176  	}()
   177  
   178  	return s
   179  }
   180  
   181  func (i *IdleSession) updateLatestStreamTime() {
   182  	now := time.Now()
   183  	i.lastStreamTime.Store(&now)
   184  }
   185  
   186  func (i *IdleSession) OpenStream(ctx context.Context) (*yamux.Stream, error) {
   187  	i.updateLatestStreamTime()
   188  	return i.Session.OpenStream(ctx)
   189  }
   190  
   191  func (i *IdleSession) Open(ctx context.Context) (net.Conn, error) {
   192  	i.updateLatestStreamTime()
   193  	return i.Session.Open(ctx)
   194  }
   195  
   196  func (i *IdleSession) IsClosed() bool {
   197  	if i.closed {
   198  		return true
   199  	}
   200  
   201  	return i.Session.IsClosed()
   202  }
   203  
   204  type MuxConn interface {
   205  	net.Conn
   206  	StreamID() uint32
   207  }
   208  
   209  type muxConn struct {
   210  	MuxConn // must not *yamux.Stream, the close write is not a really close write
   211  }
   212  
   213  func (m *muxConn) RemoteAddr() net.Addr {
   214  	return &MuxAddr{
   215  		Addr: m.MuxConn.RemoteAddr(),
   216  		ID:   m.StreamID(),
   217  	}
   218  }
   219  
   220  // func (m *muxConn) Read(p []byte) (n int, err error) {
   221  // 	n, err = m.MuxConn.Read(p)
   222  // 	if err != nil {
   223  // 		if errors.Is(err, yamux.ErrStreamReset) || errors.Is(err, yamux.ErrStreamClosed) {
   224  // 			err = io.EOF
   225  // 		}
   226  // 	}
   227  
   228  // 	return
   229  // }
   230  
   231  type MuxAddr struct {
   232  	Addr net.Addr
   233  	ID   uint32
   234  }
   235  
   236  func (q *MuxAddr) String() string  { return fmt.Sprintf("yamux://%d@%v", q.ID, q.Addr) }
   237  func (q *MuxAddr) Network() string { return "tcp" }
   238  
   239  type rangeSelector struct {
   240  	content []*connEntry
   241  	cap     uint64
   242  	count   atomic.Uint64
   243  }
   244  
   245  func NewRangeSelector(cap int) *rangeSelector {
   246  	content := make([]*connEntry, cap)
   247  
   248  	for i := 0; i < cap; i++ {
   249  		content[i] = &connEntry{}
   250  	}
   251  
   252  	return &rangeSelector{
   253  		content: content,
   254  		cap:     uint64(cap),
   255  	}
   256  }
   257  
   258  func (s *rangeSelector) Select() *connEntry {
   259  	return s.content[s.count.Add(1)%s.cap]
   260  }