github.com/sagernet/sing-box@v1.2.7/transport/hysteria/protocol.go (about)

     1  package hysteria
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"io"
     7  	"math/rand"
     8  	"net"
     9  	"os"
    10  	"time"
    11  
    12  	"github.com/sagernet/quic-go"
    13  	"github.com/sagernet/sing/common"
    14  	"github.com/sagernet/sing/common/buf"
    15  	E "github.com/sagernet/sing/common/exceptions"
    16  	M "github.com/sagernet/sing/common/metadata"
    17  )
    18  
    19  const (
    20  	MbpsToBps                      = 125000
    21  	MinSpeedBPS                    = 16384
    22  	DefaultStreamReceiveWindow     = 15728640 // 15 MB/s
    23  	DefaultConnectionReceiveWindow = 67108864 // 64 MB/s
    24  	DefaultMaxIncomingStreams      = 1024
    25  	DefaultALPN                    = "hysteria"
    26  	KeepAlivePeriod                = 10 * time.Second
    27  )
    28  
    29  const Version = 3
    30  
    31  type ClientHello struct {
    32  	SendBPS uint64
    33  	RecvBPS uint64
    34  	Auth    []byte
    35  }
    36  
    37  func WriteClientHello(stream io.Writer, hello ClientHello) error {
    38  	var requestLen int
    39  	requestLen += 1 // version
    40  	requestLen += 8 // sendBPS
    41  	requestLen += 8 // recvBPS
    42  	requestLen += 2 // auth len
    43  	requestLen += len(hello.Auth)
    44  	_request := buf.StackNewSize(requestLen)
    45  	defer common.KeepAlive(_request)
    46  	request := common.Dup(_request)
    47  	defer request.Release()
    48  	common.Must(
    49  		request.WriteByte(Version),
    50  		binary.Write(request, binary.BigEndian, hello.SendBPS),
    51  		binary.Write(request, binary.BigEndian, hello.RecvBPS),
    52  		binary.Write(request, binary.BigEndian, uint16(len(hello.Auth))),
    53  		common.Error(request.Write(hello.Auth)),
    54  	)
    55  	return common.Error(stream.Write(request.Bytes()))
    56  }
    57  
    58  func ReadClientHello(reader io.Reader) (*ClientHello, error) {
    59  	var version uint8
    60  	err := binary.Read(reader, binary.BigEndian, &version)
    61  	if err != nil {
    62  		return nil, err
    63  	}
    64  	if version != Version {
    65  		return nil, E.New("unsupported client version: ", version)
    66  	}
    67  	var clientHello ClientHello
    68  	err = binary.Read(reader, binary.BigEndian, &clientHello.SendBPS)
    69  	if err != nil {
    70  		return nil, err
    71  	}
    72  	err = binary.Read(reader, binary.BigEndian, &clientHello.RecvBPS)
    73  	if err != nil {
    74  		return nil, err
    75  	}
    76  	var authLen uint16
    77  	err = binary.Read(reader, binary.BigEndian, &authLen)
    78  	if err != nil {
    79  		return nil, err
    80  	}
    81  	clientHello.Auth = make([]byte, authLen)
    82  	_, err = io.ReadFull(reader, clientHello.Auth)
    83  	if err != nil {
    84  		return nil, err
    85  	}
    86  	return &clientHello, nil
    87  }
    88  
    89  type ServerHello struct {
    90  	OK      bool
    91  	SendBPS uint64
    92  	RecvBPS uint64
    93  	Message string
    94  }
    95  
    96  func ReadServerHello(stream io.Reader) (*ServerHello, error) {
    97  	var responseLen int
    98  	responseLen += 1 // ok
    99  	responseLen += 8 // sendBPS
   100  	responseLen += 8 // recvBPS
   101  	responseLen += 2 // message len
   102  	_response := buf.StackNewSize(responseLen)
   103  	defer common.KeepAlive(_response)
   104  	response := common.Dup(_response)
   105  	defer response.Release()
   106  	_, err := response.ReadFullFrom(stream, responseLen)
   107  	if err != nil {
   108  		return nil, err
   109  	}
   110  	var serverHello ServerHello
   111  	serverHello.OK = response.Byte(0) == 1
   112  	serverHello.SendBPS = binary.BigEndian.Uint64(response.Range(1, 9))
   113  	serverHello.RecvBPS = binary.BigEndian.Uint64(response.Range(9, 17))
   114  	messageLen := binary.BigEndian.Uint16(response.Range(17, 19))
   115  	if messageLen == 0 {
   116  		return &serverHello, nil
   117  	}
   118  	message := make([]byte, messageLen)
   119  	_, err = io.ReadFull(stream, message)
   120  	if err != nil {
   121  		return nil, err
   122  	}
   123  	serverHello.Message = string(message)
   124  	return &serverHello, nil
   125  }
   126  
   127  func WriteServerHello(stream io.Writer, hello ServerHello) error {
   128  	var responseLen int
   129  	responseLen += 1 // ok
   130  	responseLen += 8 // sendBPS
   131  	responseLen += 8 // recvBPS
   132  	responseLen += 2 // message len
   133  	responseLen += len(hello.Message)
   134  	_response := buf.StackNewSize(responseLen)
   135  	defer common.KeepAlive(_response)
   136  	response := common.Dup(_response)
   137  	defer response.Release()
   138  	if hello.OK {
   139  		common.Must(response.WriteByte(1))
   140  	} else {
   141  		common.Must(response.WriteByte(0))
   142  	}
   143  	common.Must(
   144  		binary.Write(response, binary.BigEndian, hello.SendBPS),
   145  		binary.Write(response, binary.BigEndian, hello.RecvBPS),
   146  		binary.Write(response, binary.BigEndian, uint16(len(hello.Message))),
   147  		common.Error(response.WriteString(hello.Message)),
   148  	)
   149  	return common.Error(stream.Write(response.Bytes()))
   150  }
   151  
   152  type ClientRequest struct {
   153  	UDP  bool
   154  	Host string
   155  	Port uint16
   156  }
   157  
   158  func ReadClientRequest(stream io.Reader) (*ClientRequest, error) {
   159  	var clientRequest ClientRequest
   160  	err := binary.Read(stream, binary.BigEndian, &clientRequest.UDP)
   161  	if err != nil {
   162  		return nil, err
   163  	}
   164  	var hostLen uint16
   165  	err = binary.Read(stream, binary.BigEndian, &hostLen)
   166  	if err != nil {
   167  		return nil, err
   168  	}
   169  	host := make([]byte, hostLen)
   170  	_, err = io.ReadFull(stream, host)
   171  	if err != nil {
   172  		return nil, err
   173  	}
   174  	clientRequest.Host = string(host)
   175  	err = binary.Read(stream, binary.BigEndian, &clientRequest.Port)
   176  	if err != nil {
   177  		return nil, err
   178  	}
   179  	return &clientRequest, nil
   180  }
   181  
   182  func WriteClientRequest(stream io.Writer, request ClientRequest) error {
   183  	var requestLen int
   184  	requestLen += 1 // udp
   185  	requestLen += 2 // host len
   186  	requestLen += len(request.Host)
   187  	requestLen += 2 // port
   188  	_buffer := buf.StackNewSize(requestLen)
   189  	defer common.KeepAlive(_buffer)
   190  	buffer := common.Dup(_buffer)
   191  	defer buffer.Release()
   192  	if request.UDP {
   193  		common.Must(buffer.WriteByte(1))
   194  	} else {
   195  		common.Must(buffer.WriteByte(0))
   196  	}
   197  	common.Must(
   198  		binary.Write(buffer, binary.BigEndian, uint16(len(request.Host))),
   199  		common.Error(buffer.WriteString(request.Host)),
   200  		binary.Write(buffer, binary.BigEndian, request.Port),
   201  	)
   202  	return common.Error(stream.Write(buffer.Bytes()))
   203  }
   204  
   205  type ServerResponse struct {
   206  	OK           bool
   207  	UDPSessionID uint32
   208  	Message      string
   209  }
   210  
   211  func ReadServerResponse(stream io.Reader) (*ServerResponse, error) {
   212  	var responseLen int
   213  	responseLen += 1 // ok
   214  	responseLen += 4 // udp session id
   215  	responseLen += 2 // message len
   216  	_response := buf.StackNewSize(responseLen)
   217  	defer common.KeepAlive(_response)
   218  	response := common.Dup(_response)
   219  	defer response.Release()
   220  	_, err := response.ReadFullFrom(stream, responseLen)
   221  	if err != nil {
   222  		return nil, err
   223  	}
   224  	var serverResponse ServerResponse
   225  	serverResponse.OK = response.Byte(0) == 1
   226  	serverResponse.UDPSessionID = binary.BigEndian.Uint32(response.Range(1, 5))
   227  	messageLen := binary.BigEndian.Uint16(response.Range(5, 7))
   228  	if messageLen == 0 {
   229  		return &serverResponse, nil
   230  	}
   231  	message := make([]byte, messageLen)
   232  	_, err = io.ReadFull(stream, message)
   233  	if err != nil {
   234  		return nil, err
   235  	}
   236  	serverResponse.Message = string(message)
   237  	return &serverResponse, nil
   238  }
   239  
   240  func WriteServerResponse(stream io.Writer, response ServerResponse) error {
   241  	var responseLen int
   242  	responseLen += 1 // ok
   243  	responseLen += 4 // udp session id
   244  	responseLen += 2 // message len
   245  	responseLen += len(response.Message)
   246  	_buffer := buf.StackNewSize(responseLen)
   247  	defer common.KeepAlive(_buffer)
   248  	buffer := common.Dup(_buffer)
   249  	defer buffer.Release()
   250  	if response.OK {
   251  		common.Must(buffer.WriteByte(1))
   252  	} else {
   253  		common.Must(buffer.WriteByte(0))
   254  	}
   255  	common.Must(
   256  		binary.Write(buffer, binary.BigEndian, response.UDPSessionID),
   257  		binary.Write(buffer, binary.BigEndian, uint16(len(response.Message))),
   258  		common.Error(buffer.WriteString(response.Message)),
   259  	)
   260  	return common.Error(stream.Write(buffer.Bytes()))
   261  }
   262  
   263  type UDPMessage struct {
   264  	SessionID uint32
   265  	Host      string
   266  	Port      uint16
   267  	MsgID     uint16 // doesn't matter when not fragmented, but must not be 0 when fragmented
   268  	FragID    uint8  // doesn't matter when not fragmented, starts at 0 when fragmented
   269  	FragCount uint8  // must be 1 when not fragmented
   270  	Data      []byte
   271  }
   272  
   273  func (m UDPMessage) HeaderSize() int {
   274  	return 4 + 2 + len(m.Host) + 2 + 2 + 1 + 1 + 2
   275  }
   276  
   277  func (m UDPMessage) Size() int {
   278  	return m.HeaderSize() + len(m.Data)
   279  }
   280  
   281  func ParseUDPMessage(packet []byte) (message UDPMessage, err error) {
   282  	reader := bytes.NewReader(packet)
   283  	err = binary.Read(reader, binary.BigEndian, &message.SessionID)
   284  	if err != nil {
   285  		return
   286  	}
   287  	var hostLen uint16
   288  	err = binary.Read(reader, binary.BigEndian, &hostLen)
   289  	if err != nil {
   290  		return
   291  	}
   292  	_, err = reader.Seek(int64(hostLen), io.SeekCurrent)
   293  	if err != nil {
   294  		return
   295  	}
   296  	if 6+int(hostLen) > len(packet) {
   297  		err = E.New("invalid host length")
   298  		return
   299  	}
   300  	message.Host = string(packet[6 : 6+hostLen])
   301  	err = binary.Read(reader, binary.BigEndian, &message.Port)
   302  	if err != nil {
   303  		return
   304  	}
   305  	err = binary.Read(reader, binary.BigEndian, &message.MsgID)
   306  	if err != nil {
   307  		return
   308  	}
   309  	err = binary.Read(reader, binary.BigEndian, &message.FragID)
   310  	if err != nil {
   311  		return
   312  	}
   313  	err = binary.Read(reader, binary.BigEndian, &message.FragCount)
   314  	if err != nil {
   315  		return
   316  	}
   317  	var dataLen uint16
   318  	err = binary.Read(reader, binary.BigEndian, &dataLen)
   319  	if err != nil {
   320  		return
   321  	}
   322  	if reader.Len() != int(dataLen) {
   323  		err = E.New("invalid data length")
   324  	}
   325  	dataOffset := int(reader.Size()) - reader.Len()
   326  	message.Data = packet[dataOffset:]
   327  	return
   328  }
   329  
   330  func WriteUDPMessage(conn quic.Connection, message UDPMessage) error {
   331  	var messageLen int
   332  	messageLen += 4 // session id
   333  	messageLen += 2 // host len
   334  	messageLen += len(message.Host)
   335  	messageLen += 2 // port
   336  	messageLen += 2 // msg id
   337  	messageLen += 1 // frag id
   338  	messageLen += 1 // frag count
   339  	messageLen += 2 // data len
   340  	messageLen += len(message.Data)
   341  	_buffer := buf.StackNewSize(messageLen)
   342  	defer common.KeepAlive(_buffer)
   343  	buffer := common.Dup(_buffer)
   344  	defer buffer.Release()
   345  	err := writeUDPMessage(conn, message, buffer)
   346  	if errSize, ok := err.(quic.ErrMessageToLarge); ok {
   347  		// need to frag
   348  		message.MsgID = uint16(rand.Intn(0xFFFF)) + 1 // msgID must be > 0 when fragCount > 1
   349  		fragMsgs := FragUDPMessage(message, int(errSize))
   350  		for _, fragMsg := range fragMsgs {
   351  			buffer.FullReset()
   352  			err = writeUDPMessage(conn, fragMsg, buffer)
   353  			if err != nil {
   354  				return err
   355  			}
   356  		}
   357  		return nil
   358  	}
   359  	return err
   360  }
   361  
   362  func writeUDPMessage(conn quic.Connection, message UDPMessage, buffer *buf.Buffer) error {
   363  	common.Must(
   364  		binary.Write(buffer, binary.BigEndian, message.SessionID),
   365  		binary.Write(buffer, binary.BigEndian, uint16(len(message.Host))),
   366  		common.Error(buffer.WriteString(message.Host)),
   367  		binary.Write(buffer, binary.BigEndian, message.Port),
   368  		binary.Write(buffer, binary.BigEndian, message.MsgID),
   369  		binary.Write(buffer, binary.BigEndian, message.FragID),
   370  		binary.Write(buffer, binary.BigEndian, message.FragCount),
   371  		binary.Write(buffer, binary.BigEndian, uint16(len(message.Data))),
   372  		common.Error(buffer.Write(message.Data)),
   373  	)
   374  	return conn.SendMessage(buffer.Bytes())
   375  }
   376  
   377  var _ net.Conn = (*Conn)(nil)
   378  
   379  type Conn struct {
   380  	quic.Stream
   381  	destination      M.Socksaddr
   382  	needReadResponse bool
   383  }
   384  
   385  func NewConn(stream quic.Stream, destination M.Socksaddr, isClient bool) *Conn {
   386  	return &Conn{
   387  		Stream:           stream,
   388  		destination:      destination,
   389  		needReadResponse: isClient,
   390  	}
   391  }
   392  
   393  func (c *Conn) Read(p []byte) (n int, err error) {
   394  	if c.needReadResponse {
   395  		var response *ServerResponse
   396  		response, err = ReadServerResponse(c.Stream)
   397  		if err != nil {
   398  			c.Close()
   399  			return
   400  		}
   401  		if !response.OK {
   402  			c.Close()
   403  			return 0, E.New("remote error: ", response.Message)
   404  		}
   405  		c.needReadResponse = false
   406  	}
   407  	return c.Stream.Read(p)
   408  }
   409  
   410  func (c *Conn) LocalAddr() net.Addr {
   411  	return nil
   412  }
   413  
   414  func (c *Conn) RemoteAddr() net.Addr {
   415  	return c.destination.TCPAddr()
   416  }
   417  
   418  func (c *Conn) ReaderReplaceable() bool {
   419  	return !c.needReadResponse
   420  }
   421  
   422  func (c *Conn) WriterReplaceable() bool {
   423  	return true
   424  }
   425  
   426  func (c *Conn) Upstream() any {
   427  	return c.Stream
   428  }
   429  
   430  type PacketConn struct {
   431  	session     quic.Connection
   432  	stream      quic.Stream
   433  	sessionId   uint32
   434  	destination M.Socksaddr
   435  	msgCh       <-chan *UDPMessage
   436  	closer      io.Closer
   437  }
   438  
   439  func NewPacketConn(session quic.Connection, stream quic.Stream, sessionId uint32, destination M.Socksaddr, msgCh <-chan *UDPMessage, closer io.Closer) *PacketConn {
   440  	return &PacketConn{
   441  		session:     session,
   442  		stream:      stream,
   443  		sessionId:   sessionId,
   444  		destination: destination,
   445  		msgCh:       msgCh,
   446  		closer:      closer,
   447  	}
   448  }
   449  
   450  func (c *PacketConn) Hold() {
   451  	// Hold the stream until it's closed
   452  	buf := make([]byte, 1024)
   453  	for {
   454  		_, err := c.stream.Read(buf)
   455  		if err != nil {
   456  			break
   457  		}
   458  	}
   459  	_ = c.Close()
   460  }
   461  
   462  func (c *PacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
   463  	msg := <-c.msgCh
   464  	if msg == nil {
   465  		err = net.ErrClosed
   466  		return
   467  	}
   468  	err = common.Error(buffer.Write(msg.Data))
   469  	destination = M.ParseSocksaddrHostPort(msg.Host, msg.Port).Unwrap()
   470  	return
   471  }
   472  
   473  func (c *PacketConn) ReadPacketThreadSafe() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
   474  	msg := <-c.msgCh
   475  	if msg == nil {
   476  		err = net.ErrClosed
   477  		return
   478  	}
   479  	buffer = buf.As(msg.Data)
   480  	destination = M.ParseSocksaddrHostPort(msg.Host, msg.Port).Unwrap()
   481  	return
   482  }
   483  
   484  func (c *PacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
   485  	return WriteUDPMessage(c.session, UDPMessage{
   486  		SessionID: c.sessionId,
   487  		Host:      destination.AddrString(),
   488  		Port:      destination.Port,
   489  		FragCount: 1,
   490  		Data:      buffer.Bytes(),
   491  	})
   492  }
   493  
   494  func (c *PacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
   495  	msg := <-c.msgCh
   496  	if msg == nil {
   497  		err = net.ErrClosed
   498  		return
   499  	}
   500  	n = copy(p, msg.Data)
   501  	destination := M.ParseSocksaddrHostPort(msg.Host, msg.Port)
   502  	if destination.IsFqdn() {
   503  		addr = destination
   504  	} else {
   505  		addr = destination.UDPAddr()
   506  	}
   507  	return
   508  }
   509  
   510  func (c *PacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
   511  	err = c.WritePacket(buf.As(p), M.SocksaddrFromNet(addr))
   512  	if err == nil {
   513  		n = len(p)
   514  	}
   515  	return
   516  }
   517  
   518  func (c *PacketConn) LocalAddr() net.Addr {
   519  	return nil
   520  }
   521  
   522  func (c *PacketConn) RemoteAddr() net.Addr {
   523  	return c.destination.UDPAddr()
   524  }
   525  
   526  func (c *PacketConn) SetDeadline(t time.Time) error {
   527  	return os.ErrInvalid
   528  }
   529  
   530  func (c *PacketConn) SetReadDeadline(t time.Time) error {
   531  	return os.ErrInvalid
   532  }
   533  
   534  func (c *PacketConn) SetWriteDeadline(t time.Time) error {
   535  	return os.ErrInvalid
   536  }
   537  
   538  func (c *PacketConn) NeedAdditionalReadDeadline() bool {
   539  	return true
   540  }
   541  
   542  func (c *PacketConn) Read(b []byte) (n int, err error) {
   543  	return 0, os.ErrInvalid
   544  }
   545  
   546  func (c *PacketConn) Write(b []byte) (n int, err error) {
   547  	return 0, os.ErrInvalid
   548  }
   549  
   550  func (c *PacketConn) Close() error {
   551  	return common.Close(c.stream, c.closer)
   552  }