github.com/sagernet/sing-box@v1.2.7/common/mux/protocol.go (about)

     1  package mux
     2  
import (
	"encoding/binary"
	"errors"
	"io"
	"net"

	C "github.com/sagernet/sing-box/constant"
	"github.com/sagernet/sing/common"
	"github.com/sagernet/sing/common/buf"
	E "github.com/sagernet/sing/common/exceptions"
	M "github.com/sagernet/sing/common/metadata"
	N "github.com/sagernet/sing/common/network"
	"github.com/sagernet/sing/common/rw"
	"github.com/sagernet/smux"

	"github.com/hashicorp/yamux"
)
    19  
// Destination is the placeholder address (an FQDN plus port) that
// EncodeStreamRequest substitutes when a packet-address request carries no
// valid destination of its own.
// NOTE(review): presumably also used elsewhere as the marker destination for
// multiplexed connections — confirm against callers outside this file.
var Destination = M.Socksaddr{
	Fqdn: "sp.mux.sing-box.arpa",
	Port: 444,
}
    24  
    25  const (
    26  	ProtocolSMux Protocol = iota
    27  	ProtocolYAMux
    28  )
    29  
    30  type Protocol byte
    31  
    32  func ParseProtocol(name string) (Protocol, error) {
    33  	switch name {
    34  	case "", "smux":
    35  		return ProtocolSMux, nil
    36  	case "yamux":
    37  		return ProtocolYAMux, nil
    38  	default:
    39  		return ProtocolYAMux, E.New("unknown multiplex protocol: ", name)
    40  	}
    41  }
    42  
    43  func (p Protocol) newServer(conn net.Conn) (abstractSession, error) {
    44  	switch p {
    45  	case ProtocolSMux:
    46  		session, err := smux.Server(conn, smuxConfig())
    47  		if err != nil {
    48  			return nil, err
    49  		}
    50  		return &smuxSession{session}, nil
    51  	case ProtocolYAMux:
    52  		return yamux.Server(conn, yaMuxConfig())
    53  	default:
    54  		panic("unknown protocol")
    55  	}
    56  }
    57  
    58  func (p Protocol) newClient(conn net.Conn) (abstractSession, error) {
    59  	switch p {
    60  	case ProtocolSMux:
    61  		session, err := smux.Client(conn, smuxConfig())
    62  		if err != nil {
    63  			return nil, err
    64  		}
    65  		return &smuxSession{session}, nil
    66  	case ProtocolYAMux:
    67  		return yamux.Client(conn, yaMuxConfig())
    68  	default:
    69  		panic("unknown protocol")
    70  	}
    71  }
    72  
    73  func smuxConfig() *smux.Config {
    74  	config := smux.DefaultConfig()
    75  	config.KeepAliveDisabled = true
    76  	return config
    77  }
    78  
    79  func yaMuxConfig() *yamux.Config {
    80  	config := yamux.DefaultConfig()
    81  	config.LogOutput = io.Discard
    82  	config.StreamCloseTimeout = C.TCPTimeout
    83  	config.StreamOpenTimeout = C.TCPTimeout
    84  	return config
    85  }
    86  
    87  func (p Protocol) String() string {
    88  	switch p {
    89  	case ProtocolSMux:
    90  		return "smux"
    91  	case ProtocolYAMux:
    92  		return "yamux"
    93  	default:
    94  		return "unknown"
    95  	}
    96  }
    97  
    98  const (
    99  	version0 = 0
   100  )
   101  
   102  type Request struct {
   103  	Protocol Protocol
   104  }
   105  
   106  func ReadRequest(reader io.Reader) (*Request, error) {
   107  	version, err := rw.ReadByte(reader)
   108  	if err != nil {
   109  		return nil, err
   110  	}
   111  	if version != version0 {
   112  		return nil, E.New("unsupported version: ", version)
   113  	}
   114  	protocol, err := rw.ReadByte(reader)
   115  	if err != nil {
   116  		return nil, err
   117  	}
   118  	if protocol > byte(ProtocolYAMux) {
   119  		return nil, E.New("unsupported protocol: ", protocol)
   120  	}
   121  	return &Request{Protocol: Protocol(protocol)}, nil
   122  }
   123  
   124  func EncodeRequest(buffer *buf.Buffer, request Request) {
   125  	buffer.WriteByte(version0)
   126  	buffer.WriteByte(byte(request.Protocol))
   127  }
   128  
   129  const (
   130  	flagUDP       = 1
   131  	flagAddr      = 2
   132  	statusSuccess = 0
   133  	statusError   = 1
   134  )
   135  
// StreamRequest describes a single stream opened inside a multiplexed
// session.
//
// The field order is relied upon by the positional literal in
// ReadStreamRequest; do not reorder.
type StreamRequest struct {
	// Network is N.NetworkTCP or N.NetworkUDP when produced by
	// ReadStreamRequest.
	Network string
	// Destination is the address the stream should be connected to.
	Destination M.Socksaddr
	// PacketAddr corresponds to flagAddr on the wire. ReadStreamRequest only
	// reports it for UDP streams.
	PacketAddr bool
}
   141  
   142  func ReadStreamRequest(reader io.Reader) (*StreamRequest, error) {
   143  	var flags uint16
   144  	err := binary.Read(reader, binary.BigEndian, &flags)
   145  	if err != nil {
   146  		return nil, err
   147  	}
   148  	destination, err := M.SocksaddrSerializer.ReadAddrPort(reader)
   149  	if err != nil {
   150  		return nil, err
   151  	}
   152  	var network string
   153  	var udpAddr bool
   154  	if flags&flagUDP == 0 {
   155  		network = N.NetworkTCP
   156  	} else {
   157  		network = N.NetworkUDP
   158  		udpAddr = flags&flagAddr != 0
   159  	}
   160  	return &StreamRequest{network, destination, udpAddr}, nil
   161  }
   162  
   163  func requestLen(request StreamRequest) int {
   164  	var rLen int
   165  	rLen += 1 // version
   166  	rLen += 2 // flags
   167  	rLen += M.SocksaddrSerializer.AddrPortLen(request.Destination)
   168  	return rLen
   169  }
   170  
   171  func EncodeStreamRequest(request StreamRequest, buffer *buf.Buffer) {
   172  	destination := request.Destination
   173  	var flags uint16
   174  	if request.Network == N.NetworkUDP {
   175  		flags |= flagUDP
   176  	}
   177  	if request.PacketAddr {
   178  		flags |= flagAddr
   179  		if !destination.IsValid() {
   180  			destination = Destination
   181  		}
   182  	}
   183  	common.Must(
   184  		binary.Write(buffer, binary.BigEndian, flags),
   185  		M.SocksaddrSerializer.WriteAddrPort(buffer, destination),
   186  	)
   187  }
   188  
// StreamResponse is the reply to a stream request as decoded by
// ReadStreamResponse.
type StreamResponse struct {
	// Status is the raw status byte (statusSuccess or statusError).
	Status uint8
	// Message is only populated when Status equals statusError.
	Message string
}
   193  
   194  func ReadStreamResponse(reader io.Reader) (*StreamResponse, error) {
   195  	var response StreamResponse
   196  	status, err := rw.ReadByte(reader)
   197  	if err != nil {
   198  		return nil, err
   199  	}
   200  	response.Status = status
   201  	if status == statusError {
   202  		response.Message, err = rw.ReadVString(reader)
   203  		if err != nil {
   204  			return nil, err
   205  		}
   206  	}
   207  	return &response, nil
   208  }
   209  
// wrapStream embeds a net.Conn and routes the errors returned by its Read
// and Write methods through wrapError.
type wrapStream struct {
	net.Conn
}
   213  
   214  func (w *wrapStream) Read(p []byte) (n int, err error) {
   215  	n, err = w.Conn.Read(p)
   216  	err = wrapError(err)
   217  	return
   218  }
   219  
   220  func (w *wrapStream) Write(p []byte) (n int, err error) {
   221  	n, err = w.Conn.Write(p)
   222  	err = wrapError(err)
   223  	return
   224  }
   225  
// WriteIsThreadUnsafe is intentionally empty.
// NOTE(review): presumably a marker method that other packages detect via an
// interface assertion to learn that Write must not be called concurrently —
// confirm against the sing library.
func (w *wrapStream) WriteIsThreadUnsafe() {
}
   228  
// Upstream returns the embedded connection that this wrapStream wraps.
func (w *wrapStream) Upstream() any {
	return w.Conn
}
   232  
   233  func wrapError(err error) error {
   234  	switch err {
   235  	case yamux.ErrStreamClosed:
   236  		return io.EOF
   237  	default:
   238  		return err
   239  	}
   240  }