github.com/sagernet/sing-mux@v0.2.1-0.20240124034317-9bfb33698bb6/session.go (about)

     1  package mux
     2  
     3  import (
     4  	"io"
     5  	"net"
     6  	"reflect"
     7  
     8  	E "github.com/sagernet/sing/common/exceptions"
     9  	"github.com/sagernet/smux"
    10  
    11  	"github.com/hashicorp/yamux"
    12  )
    13  
    14  type abstractSession interface {
    15  	Open() (net.Conn, error)
    16  	Accept() (net.Conn, error)
    17  	NumStreams() int
    18  	Close() error
    19  	IsClosed() bool
    20  	CanTakeNewRequest() bool
    21  }
    22  
    23  func newClientSession(conn net.Conn, protocol byte) (abstractSession, error) {
    24  	switch protocol {
    25  	case ProtocolH2Mux:
    26  		session, err := newH2MuxClient(conn)
    27  		if err != nil {
    28  			return nil, err
    29  		}
    30  		return session, nil
    31  	case ProtocolSmux:
    32  		client, err := smux.Client(conn, smuxConfig())
    33  		if err != nil {
    34  			return nil, err
    35  		}
    36  		return &smuxSession{client}, nil
    37  	case ProtocolYAMux:
    38  		checkYAMuxConn(conn)
    39  		client, err := yamux.Client(conn, yaMuxConfig())
    40  		if err != nil {
    41  			return nil, err
    42  		}
    43  		return &yamuxSession{client}, nil
    44  	default:
    45  		return nil, E.New("unexpected protocol ", protocol)
    46  	}
    47  }
    48  
    49  func newServerSession(conn net.Conn, protocol byte) (abstractSession, error) {
    50  	switch protocol {
    51  	case ProtocolH2Mux:
    52  		return newH2MuxServer(conn), nil
    53  	case ProtocolSmux:
    54  		client, err := smux.Server(conn, smuxConfig())
    55  		if err != nil {
    56  			return nil, err
    57  		}
    58  		return &smuxSession{client}, nil
    59  	case ProtocolYAMux:
    60  		checkYAMuxConn(conn)
    61  		client, err := yamux.Server(conn, yaMuxConfig())
    62  		if err != nil {
    63  			return nil, err
    64  		}
    65  		return &yamuxSession{client}, nil
    66  	default:
    67  		return nil, E.New("unexpected protocol ", protocol)
    68  	}
    69  }
    70  
    71  func checkYAMuxConn(conn net.Conn) {
    72  	if conn.LocalAddr() == nil || conn.RemoteAddr() == nil {
    73  		panic("found net.Conn with nil addr: " + reflect.TypeOf(conn).String())
    74  	}
    75  }
    76  
    77  var _ abstractSession = (*smuxSession)(nil)
    78  
    79  type smuxSession struct {
    80  	*smux.Session
    81  }
    82  
    83  func (s *smuxSession) Open() (net.Conn, error) {
    84  	return s.OpenStream()
    85  }
    86  
    87  func (s *smuxSession) Accept() (net.Conn, error) {
    88  	return s.AcceptStream()
    89  }
    90  
    91  func (s *smuxSession) CanTakeNewRequest() bool {
    92  	return true
    93  }
    94  
    95  type yamuxSession struct {
    96  	*yamux.Session
    97  }
    98  
    99  func (y *yamuxSession) CanTakeNewRequest() bool {
   100  	return true
   101  }
   102  
   103  func smuxConfig() *smux.Config {
   104  	config := smux.DefaultConfig()
   105  	config.KeepAliveDisabled = true
   106  	return config
   107  }
   108  
   109  func yaMuxConfig() *yamux.Config {
   110  	config := yamux.DefaultConfig()
   111  	config.LogOutput = io.Discard
   112  	config.StreamCloseTimeout = TCPTimeout
   113  	config.StreamOpenTimeout = TCPTimeout
   114  	return config
   115  }