github.com/inazumav/sing-box@v0.0.0-20230926072359-ab51429a14f1/transport/hysteria/protocol.go (about)

     1  package hysteria
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"io"
     7  	"math/rand"
     8  	"net"
     9  	"os"
    10  	"time"
    11  
    12  	"github.com/sagernet/quic-go"
    13  	"github.com/sagernet/sing/common"
    14  	"github.com/sagernet/sing/common/buf"
    15  	E "github.com/sagernet/sing/common/exceptions"
    16  	M "github.com/sagernet/sing/common/metadata"
    17  )
    18  
    19  const (
    20  	MbpsToBps                      = 125000
    21  	MinSpeedBPS                    = 16384
    22  	DefaultStreamReceiveWindow     = 15728640 // 15 MB/s
    23  	DefaultConnectionReceiveWindow = 67108864 // 64 MB/s
    24  	DefaultMaxIncomingStreams      = 1024
    25  	DefaultALPN                    = "hysteria"
    26  	KeepAlivePeriod                = 10 * time.Second
    27  )
    28  
    29  const Version = 3
    30  
    31  type ClientHello struct {
    32  	SendBPS uint64
    33  	RecvBPS uint64
    34  	Auth    []byte
    35  }
    36  
    37  func WriteClientHello(stream io.Writer, hello ClientHello) error {
    38  	var requestLen int
    39  	requestLen += 1 // version
    40  	requestLen += 8 // sendBPS
    41  	requestLen += 8 // recvBPS
    42  	requestLen += 2 // auth len
    43  	requestLen += len(hello.Auth)
    44  	request := buf.NewSize(requestLen)
    45  	defer request.Release()
    46  	common.Must(
    47  		request.WriteByte(Version),
    48  		binary.Write(request, binary.BigEndian, hello.SendBPS),
    49  		binary.Write(request, binary.BigEndian, hello.RecvBPS),
    50  		binary.Write(request, binary.BigEndian, uint16(len(hello.Auth))),
    51  		common.Error(request.Write(hello.Auth)),
    52  	)
    53  	return common.Error(stream.Write(request.Bytes()))
    54  }
    55  
    56  func ReadClientHello(reader io.Reader) (*ClientHello, error) {
    57  	var version uint8
    58  	err := binary.Read(reader, binary.BigEndian, &version)
    59  	if err != nil {
    60  		return nil, err
    61  	}
    62  	if version != Version {
    63  		return nil, E.New("unsupported client version: ", version)
    64  	}
    65  	var clientHello ClientHello
    66  	err = binary.Read(reader, binary.BigEndian, &clientHello.SendBPS)
    67  	if err != nil {
    68  		return nil, err
    69  	}
    70  	err = binary.Read(reader, binary.BigEndian, &clientHello.RecvBPS)
    71  	if err != nil {
    72  		return nil, err
    73  	}
    74  	var authLen uint16
    75  	err = binary.Read(reader, binary.BigEndian, &authLen)
    76  	if err != nil {
    77  		return nil, err
    78  	}
    79  	clientHello.Auth = make([]byte, authLen)
    80  	_, err = io.ReadFull(reader, clientHello.Auth)
    81  	if err != nil {
    82  		return nil, err
    83  	}
    84  	return &clientHello, nil
    85  }
    86  
    87  type ServerHello struct {
    88  	OK      bool
    89  	SendBPS uint64
    90  	RecvBPS uint64
    91  	Message string
    92  }
    93  
    94  func ReadServerHello(stream io.Reader) (*ServerHello, error) {
    95  	var responseLen int
    96  	responseLen += 1 // ok
    97  	responseLen += 8 // sendBPS
    98  	responseLen += 8 // recvBPS
    99  	responseLen += 2 // message len
   100  	response := buf.NewSize(responseLen)
   101  	defer response.Release()
   102  	_, err := response.ReadFullFrom(stream, responseLen)
   103  	if err != nil {
   104  		return nil, err
   105  	}
   106  	var serverHello ServerHello
   107  	serverHello.OK = response.Byte(0) == 1
   108  	serverHello.SendBPS = binary.BigEndian.Uint64(response.Range(1, 9))
   109  	serverHello.RecvBPS = binary.BigEndian.Uint64(response.Range(9, 17))
   110  	messageLen := binary.BigEndian.Uint16(response.Range(17, 19))
   111  	if messageLen == 0 {
   112  		return &serverHello, nil
   113  	}
   114  	message := make([]byte, messageLen)
   115  	_, err = io.ReadFull(stream, message)
   116  	if err != nil {
   117  		return nil, err
   118  	}
   119  	serverHello.Message = string(message)
   120  	return &serverHello, nil
   121  }
   122  
   123  func WriteServerHello(stream io.Writer, hello ServerHello) error {
   124  	var responseLen int
   125  	responseLen += 1 // ok
   126  	responseLen += 8 // sendBPS
   127  	responseLen += 8 // recvBPS
   128  	responseLen += 2 // message len
   129  	responseLen += len(hello.Message)
   130  	response := buf.NewSize(responseLen)
   131  	defer response.Release()
   132  	if hello.OK {
   133  		common.Must(response.WriteByte(1))
   134  	} else {
   135  		common.Must(response.WriteByte(0))
   136  	}
   137  	common.Must(
   138  		binary.Write(response, binary.BigEndian, hello.SendBPS),
   139  		binary.Write(response, binary.BigEndian, hello.RecvBPS),
   140  		binary.Write(response, binary.BigEndian, uint16(len(hello.Message))),
   141  		common.Error(response.WriteString(hello.Message)),
   142  	)
   143  	return common.Error(stream.Write(response.Bytes()))
   144  }
   145  
   146  type ClientRequest struct {
   147  	UDP  bool
   148  	Host string
   149  	Port uint16
   150  }
   151  
   152  func ReadClientRequest(stream io.Reader) (*ClientRequest, error) {
   153  	var clientRequest ClientRequest
   154  	err := binary.Read(stream, binary.BigEndian, &clientRequest.UDP)
   155  	if err != nil {
   156  		return nil, err
   157  	}
   158  	var hostLen uint16
   159  	err = binary.Read(stream, binary.BigEndian, &hostLen)
   160  	if err != nil {
   161  		return nil, err
   162  	}
   163  	host := make([]byte, hostLen)
   164  	_, err = io.ReadFull(stream, host)
   165  	if err != nil {
   166  		return nil, err
   167  	}
   168  	clientRequest.Host = string(host)
   169  	err = binary.Read(stream, binary.BigEndian, &clientRequest.Port)
   170  	if err != nil {
   171  		return nil, err
   172  	}
   173  	return &clientRequest, nil
   174  }
   175  
   176  func WriteClientRequest(stream io.Writer, request ClientRequest) error {
   177  	var requestLen int
   178  	requestLen += 1 // udp
   179  	requestLen += 2 // host len
   180  	requestLen += len(request.Host)
   181  	requestLen += 2 // port
   182  	buffer := buf.NewSize(requestLen)
   183  	defer buffer.Release()
   184  	if request.UDP {
   185  		common.Must(buffer.WriteByte(1))
   186  	} else {
   187  		common.Must(buffer.WriteByte(0))
   188  	}
   189  	common.Must(
   190  		binary.Write(buffer, binary.BigEndian, uint16(len(request.Host))),
   191  		common.Error(buffer.WriteString(request.Host)),
   192  		binary.Write(buffer, binary.BigEndian, request.Port),
   193  	)
   194  	return common.Error(stream.Write(buffer.Bytes()))
   195  }
   196  
   197  type ServerResponse struct {
   198  	OK           bool
   199  	UDPSessionID uint32
   200  	Message      string
   201  }
   202  
   203  func ReadServerResponse(stream io.Reader) (*ServerResponse, error) {
   204  	var responseLen int
   205  	responseLen += 1 // ok
   206  	responseLen += 4 // udp session id
   207  	responseLen += 2 // message len
   208  	response := buf.NewSize(responseLen)
   209  	defer response.Release()
   210  	_, err := response.ReadFullFrom(stream, responseLen)
   211  	if err != nil {
   212  		return nil, err
   213  	}
   214  	var serverResponse ServerResponse
   215  	serverResponse.OK = response.Byte(0) == 1
   216  	serverResponse.UDPSessionID = binary.BigEndian.Uint32(response.Range(1, 5))
   217  	messageLen := binary.BigEndian.Uint16(response.Range(5, 7))
   218  	if messageLen == 0 {
   219  		return &serverResponse, nil
   220  	}
   221  	message := make([]byte, messageLen)
   222  	_, err = io.ReadFull(stream, message)
   223  	if err != nil {
   224  		return nil, err
   225  	}
   226  	serverResponse.Message = string(message)
   227  	return &serverResponse, nil
   228  }
   229  
   230  func WriteServerResponse(stream io.Writer, response ServerResponse) error {
   231  	var responseLen int
   232  	responseLen += 1 // ok
   233  	responseLen += 4 // udp session id
   234  	responseLen += 2 // message len
   235  	responseLen += len(response.Message)
   236  	buffer := buf.NewSize(responseLen)
   237  	defer buffer.Release()
   238  	if response.OK {
   239  		common.Must(buffer.WriteByte(1))
   240  	} else {
   241  		common.Must(buffer.WriteByte(0))
   242  	}
   243  	common.Must(
   244  		binary.Write(buffer, binary.BigEndian, response.UDPSessionID),
   245  		binary.Write(buffer, binary.BigEndian, uint16(len(response.Message))),
   246  		common.Error(buffer.WriteString(response.Message)),
   247  	)
   248  	return common.Error(stream.Write(buffer.Bytes()))
   249  }
   250  
   251  type UDPMessage struct {
   252  	SessionID uint32
   253  	Host      string
   254  	Port      uint16
   255  	MsgID     uint16 // doesn't matter when not fragmented, but must not be 0 when fragmented
   256  	FragID    uint8  // doesn't matter when not fragmented, starts at 0 when fragmented
   257  	FragCount uint8  // must be 1 when not fragmented
   258  	Data      []byte
   259  }
   260  
   261  func (m UDPMessage) HeaderSize() int {
   262  	return 4 + 2 + len(m.Host) + 2 + 2 + 1 + 1 + 2
   263  }
   264  
   265  func (m UDPMessage) Size() int {
   266  	return m.HeaderSize() + len(m.Data)
   267  }
   268  
   269  func ParseUDPMessage(packet []byte) (message UDPMessage, err error) {
   270  	reader := bytes.NewReader(packet)
   271  	err = binary.Read(reader, binary.BigEndian, &message.SessionID)
   272  	if err != nil {
   273  		return
   274  	}
   275  	var hostLen uint16
   276  	err = binary.Read(reader, binary.BigEndian, &hostLen)
   277  	if err != nil {
   278  		return
   279  	}
   280  	_, err = reader.Seek(int64(hostLen), io.SeekCurrent)
   281  	if err != nil {
   282  		return
   283  	}
   284  	if 6+int(hostLen) > len(packet) {
   285  		err = E.New("invalid host length")
   286  		return
   287  	}
   288  	message.Host = string(packet[6 : 6+hostLen])
   289  	err = binary.Read(reader, binary.BigEndian, &message.Port)
   290  	if err != nil {
   291  		return
   292  	}
   293  	err = binary.Read(reader, binary.BigEndian, &message.MsgID)
   294  	if err != nil {
   295  		return
   296  	}
   297  	err = binary.Read(reader, binary.BigEndian, &message.FragID)
   298  	if err != nil {
   299  		return
   300  	}
   301  	err = binary.Read(reader, binary.BigEndian, &message.FragCount)
   302  	if err != nil {
   303  		return
   304  	}
   305  	var dataLen uint16
   306  	err = binary.Read(reader, binary.BigEndian, &dataLen)
   307  	if err != nil {
   308  		return
   309  	}
   310  	if reader.Len() != int(dataLen) {
   311  		err = E.New("invalid data length")
   312  	}
   313  	dataOffset := int(reader.Size()) - reader.Len()
   314  	message.Data = packet[dataOffset:]
   315  	return
   316  }
   317  
   318  func WriteUDPMessage(conn quic.Connection, message UDPMessage) error {
   319  	var messageLen int
   320  	messageLen += 4 // session id
   321  	messageLen += 2 // host len
   322  	messageLen += len(message.Host)
   323  	messageLen += 2 // port
   324  	messageLen += 2 // msg id
   325  	messageLen += 1 // frag id
   326  	messageLen += 1 // frag count
   327  	messageLen += 2 // data len
   328  	messageLen += len(message.Data)
   329  	buffer := buf.NewSize(messageLen)
   330  	defer buffer.Release()
   331  	err := writeUDPMessage(conn, message, buffer)
   332  	if errSize, ok := err.(quic.ErrMessageTooLarge); ok {
   333  		// need to frag
   334  		message.MsgID = uint16(rand.Intn(0xFFFF)) + 1 // msgID must be > 0 when fragCount > 1
   335  		fragMsgs := FragUDPMessage(message, int(errSize))
   336  		for _, fragMsg := range fragMsgs {
   337  			buffer.FullReset()
   338  			err = writeUDPMessage(conn, fragMsg, buffer)
   339  			if err != nil {
   340  				return err
   341  			}
   342  		}
   343  		return nil
   344  	}
   345  	return err
   346  }
   347  
   348  func writeUDPMessage(conn quic.Connection, message UDPMessage, buffer *buf.Buffer) error {
   349  	common.Must(
   350  		binary.Write(buffer, binary.BigEndian, message.SessionID),
   351  		binary.Write(buffer, binary.BigEndian, uint16(len(message.Host))),
   352  		common.Error(buffer.WriteString(message.Host)),
   353  		binary.Write(buffer, binary.BigEndian, message.Port),
   354  		binary.Write(buffer, binary.BigEndian, message.MsgID),
   355  		binary.Write(buffer, binary.BigEndian, message.FragID),
   356  		binary.Write(buffer, binary.BigEndian, message.FragCount),
   357  		binary.Write(buffer, binary.BigEndian, uint16(len(message.Data))),
   358  		common.Error(buffer.Write(message.Data)),
   359  	)
   360  	return conn.SendMessage(buffer.Bytes())
   361  }
   362  
   363  var _ net.Conn = (*Conn)(nil)
   364  
   365  type Conn struct {
   366  	quic.Stream
   367  	destination      M.Socksaddr
   368  	needReadResponse bool
   369  }
   370  
   371  func NewConn(stream quic.Stream, destination M.Socksaddr, isClient bool) *Conn {
   372  	return &Conn{
   373  		Stream:           stream,
   374  		destination:      destination,
   375  		needReadResponse: isClient,
   376  	}
   377  }
   378  
   379  func (c *Conn) Read(p []byte) (n int, err error) {
   380  	if c.needReadResponse {
   381  		var response *ServerResponse
   382  		response, err = ReadServerResponse(c.Stream)
   383  		if err != nil {
   384  			c.Close()
   385  			return
   386  		}
   387  		if !response.OK {
   388  			c.Close()
   389  			return 0, E.New("remote error: ", response.Message)
   390  		}
   391  		c.needReadResponse = false
   392  	}
   393  	return c.Stream.Read(p)
   394  }
   395  
   396  func (c *Conn) LocalAddr() net.Addr {
   397  	return M.Socksaddr{}
   398  }
   399  
   400  func (c *Conn) RemoteAddr() net.Addr {
   401  	return c.destination.TCPAddr()
   402  }
   403  
   404  func (c *Conn) ReaderReplaceable() bool {
   405  	return !c.needReadResponse
   406  }
   407  
   408  func (c *Conn) WriterReplaceable() bool {
   409  	return true
   410  }
   411  
   412  func (c *Conn) Upstream() any {
   413  	return c.Stream
   414  }
   415  
   416  type PacketConn struct {
   417  	session     quic.Connection
   418  	stream      quic.Stream
   419  	sessionId   uint32
   420  	destination M.Socksaddr
   421  	msgCh       <-chan *UDPMessage
   422  	closer      io.Closer
   423  }
   424  
   425  func NewPacketConn(session quic.Connection, stream quic.Stream, sessionId uint32, destination M.Socksaddr, msgCh <-chan *UDPMessage, closer io.Closer) *PacketConn {
   426  	return &PacketConn{
   427  		session:     session,
   428  		stream:      stream,
   429  		sessionId:   sessionId,
   430  		destination: destination,
   431  		msgCh:       msgCh,
   432  		closer:      closer,
   433  	}
   434  }
   435  
   436  func (c *PacketConn) Hold() {
   437  	// Hold the stream until it's closed
   438  	buf := make([]byte, 1024)
   439  	for {
   440  		_, err := c.stream.Read(buf)
   441  		if err != nil {
   442  			break
   443  		}
   444  	}
   445  	_ = c.Close()
   446  }
   447  
   448  func (c *PacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
   449  	msg := <-c.msgCh
   450  	if msg == nil {
   451  		err = net.ErrClosed
   452  		return
   453  	}
   454  	err = common.Error(buffer.Write(msg.Data))
   455  	destination = M.ParseSocksaddrHostPort(msg.Host, msg.Port).Unwrap()
   456  	return
   457  }
   458  
   459  func (c *PacketConn) ReadPacketThreadSafe() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
   460  	msg := <-c.msgCh
   461  	if msg == nil {
   462  		err = net.ErrClosed
   463  		return
   464  	}
   465  	buffer = buf.As(msg.Data)
   466  	destination = M.ParseSocksaddrHostPort(msg.Host, msg.Port).Unwrap()
   467  	return
   468  }
   469  
   470  func (c *PacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
   471  	return WriteUDPMessage(c.session, UDPMessage{
   472  		SessionID: c.sessionId,
   473  		Host:      destination.AddrString(),
   474  		Port:      destination.Port,
   475  		FragCount: 1,
   476  		Data:      buffer.Bytes(),
   477  	})
   478  }
   479  
   480  func (c *PacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
   481  	msg := <-c.msgCh
   482  	if msg == nil {
   483  		err = net.ErrClosed
   484  		return
   485  	}
   486  	n = copy(p, msg.Data)
   487  	destination := M.ParseSocksaddrHostPort(msg.Host, msg.Port)
   488  	if destination.IsFqdn() {
   489  		addr = destination
   490  	} else {
   491  		addr = destination.UDPAddr()
   492  	}
   493  	return
   494  }
   495  
   496  func (c *PacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
   497  	err = c.WritePacket(buf.As(p), M.SocksaddrFromNet(addr))
   498  	if err == nil {
   499  		n = len(p)
   500  	}
   501  	return
   502  }
   503  
   504  func (c *PacketConn) LocalAddr() net.Addr {
   505  	return M.Socksaddr{}
   506  }
   507  
   508  func (c *PacketConn) RemoteAddr() net.Addr {
   509  	return c.destination.UDPAddr()
   510  }
   511  
   512  func (c *PacketConn) SetDeadline(t time.Time) error {
   513  	return os.ErrInvalid
   514  }
   515  
   516  func (c *PacketConn) SetReadDeadline(t time.Time) error {
   517  	return os.ErrInvalid
   518  }
   519  
   520  func (c *PacketConn) SetWriteDeadline(t time.Time) error {
   521  	return os.ErrInvalid
   522  }
   523  
   524  func (c *PacketConn) NeedAdditionalReadDeadline() bool {
   525  	return true
   526  }
   527  
   528  func (c *PacketConn) Read(b []byte) (n int, err error) {
   529  	n, _, err = c.ReadFrom(b)
   530  	return
   531  }
   532  
   533  func (c *PacketConn) Write(b []byte) (n int, err error) {
   534  	return c.WriteTo(b, c.destination)
   535  }
   536  
   537  func (c *PacketConn) Close() error {
   538  	return common.Close(c.stream, c.closer)
   539  }