github.com/inazumav/sing-box@v0.0.0-20230926072359-ab51429a14f1/transport/hysteria2/internal/protocol/proxy.go (about)

     1  package protocol
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"fmt"
     7  	"io"
     8  
     9  	"github.com/sagernet/quic-go/quicvarint"
    10  	"github.com/sagernet/sing/common"
    11  	"github.com/sagernet/sing/common/buf"
    12  	E "github.com/sagernet/sing/common/exceptions"
    13  	"github.com/sagernet/sing/common/rw"
    14  )
    15  
    16  const (
    17  	FrameTypeTCPRequest = 0x401
    18  
    19  	// Max length values are for preventing DoS attacks
    20  
    21  	MaxAddressLength = 2048
    22  	MaxMessageLength = 2048
    23  	MaxPaddingLength = 4096
    24  
    25  	MaxUDPSize = 4096
    26  
    27  	maxVarInt1 = 63
    28  	maxVarInt2 = 16383
    29  	maxVarInt4 = 1073741823
    30  	maxVarInt8 = 4611686018427387903
    31  )
    32  
    33  // TCPRequest format:
    34  // 0x401 (QUIC varint)
    35  // Address length (QUIC varint)
    36  // Address (bytes)
    37  // Padding length (QUIC varint)
    38  // Padding (bytes)
    39  
    40  func ReadTCPRequest(r io.Reader) (string, error) {
    41  	bReader := quicvarint.NewReader(r)
    42  	addrLen, err := quicvarint.Read(bReader)
    43  	if err != nil {
    44  		return "", err
    45  	}
    46  	if addrLen == 0 || addrLen > MaxAddressLength {
    47  		return "", E.New("invalid address length")
    48  	}
    49  	addrBuf := make([]byte, addrLen)
    50  	_, err = io.ReadFull(r, addrBuf)
    51  	if err != nil {
    52  		return "", err
    53  	}
    54  	paddingLen, err := quicvarint.Read(bReader)
    55  	if err != nil {
    56  		return "", err
    57  	}
    58  	if paddingLen > MaxPaddingLength {
    59  		return "", E.New("invalid padding length")
    60  	}
    61  	if paddingLen > 0 {
    62  		_, err = io.CopyN(io.Discard, r, int64(paddingLen))
    63  		if err != nil {
    64  			return "", err
    65  		}
    66  	}
    67  	return string(addrBuf), nil
    68  }
    69  
    70  func WriteTCPRequest(addr string, payload []byte) *buf.Buffer {
    71  	padding := tcpRequestPadding.String()
    72  	paddingLen := len(padding)
    73  	addrLen := len(addr)
    74  	sz := int(quicvarint.Len(FrameTypeTCPRequest)) +
    75  		int(quicvarint.Len(uint64(addrLen))) + addrLen +
    76  		int(quicvarint.Len(uint64(paddingLen))) + paddingLen
    77  	buffer := buf.NewSize(sz + len(payload))
    78  	bufferContent := buffer.Extend(sz)
    79  	i := varintPut(bufferContent, FrameTypeTCPRequest)
    80  	i += varintPut(bufferContent[i:], uint64(addrLen))
    81  	i += copy(bufferContent[i:], addr)
    82  	i += varintPut(bufferContent[i:], uint64(paddingLen))
    83  	copy(bufferContent[i:], padding)
    84  	buffer.Write(payload)
    85  	return buffer
    86  }
    87  
    88  // TCPResponse format:
    89  // Status (byte, 0=ok, 1=error)
    90  // Message length (QUIC varint)
    91  // Message (bytes)
    92  // Padding length (QUIC varint)
    93  // Padding (bytes)
    94  
    95  func ReadTCPResponse(r io.Reader) (bool, string, error) {
    96  	var status [1]byte
    97  	if _, err := io.ReadFull(r, status[:]); err != nil {
    98  		return false, "", err
    99  	}
   100  	bReader := quicvarint.NewReader(r)
   101  	msg, err := ReadVString(bReader)
   102  	if err != nil {
   103  		return false, "", err
   104  	}
   105  	paddingLen, err := quicvarint.Read(bReader)
   106  	if err != nil {
   107  		return false, "", err
   108  	}
   109  	if paddingLen > MaxPaddingLength {
   110  		return false, "", E.New("invalid padding length")
   111  	}
   112  	if paddingLen > 0 {
   113  		_, err = io.CopyN(io.Discard, r, int64(paddingLen))
   114  		if err != nil {
   115  			return false, "", err
   116  		}
   117  	}
   118  	return status[0] == 0, msg, nil
   119  }
   120  
   121  func WriteTCPResponse(ok bool, msg string, payload []byte) *buf.Buffer {
   122  	padding := tcpResponsePadding.String()
   123  	paddingLen := len(padding)
   124  	msgLen := len(msg)
   125  	sz := 1 + int(quicvarint.Len(uint64(msgLen))) + msgLen +
   126  		int(quicvarint.Len(uint64(paddingLen))) + paddingLen
   127  	buffer := buf.NewSize(sz + len(payload))
   128  	if ok {
   129  		buffer.WriteByte(0)
   130  	} else {
   131  		buffer.WriteByte(1)
   132  	}
   133  	WriteVString(buffer, msg)
   134  	WriteUVariant(buffer, uint64(paddingLen))
   135  	buffer.Extend(paddingLen)
   136  	buffer.Write(payload)
   137  	return buffer
   138  }
   139  
   140  // UDPMessage format:
   141  // Session ID (uint32 BE)
   142  // Packet ID (uint16 BE)
   143  // Fragment ID (uint8)
   144  // Fragment count (uint8)
   145  // Address length (QUIC varint)
   146  // Address (bytes)
   147  // Data...
   148  
   149  type UDPMessage struct {
   150  	SessionID uint32 // 4
   151  	PacketID  uint16 // 2
   152  	FragID    uint8  // 1
   153  	FragCount uint8  // 1
   154  	Addr      string // varint + bytes
   155  	Data      []byte
   156  }
   157  
   158  func (m *UDPMessage) HeaderSize() int {
   159  	lAddr := len(m.Addr)
   160  	return 4 + 2 + 1 + 1 + int(quicvarint.Len(uint64(lAddr))) + lAddr
   161  }
   162  
   163  func (m *UDPMessage) Size() int {
   164  	return m.HeaderSize() + len(m.Data)
   165  }
   166  
   167  func (m *UDPMessage) Serialize(buf []byte) int {
   168  	// Make sure the buffer is big enough
   169  	if len(buf) < m.Size() {
   170  		return -1
   171  	}
   172  	binary.BigEndian.PutUint32(buf, m.SessionID)
   173  	binary.BigEndian.PutUint16(buf[4:], m.PacketID)
   174  	buf[6] = m.FragID
   175  	buf[7] = m.FragCount
   176  	i := varintPut(buf[8:], uint64(len(m.Addr)))
   177  	i += copy(buf[8+i:], m.Addr)
   178  	i += copy(buf[8+i:], m.Data)
   179  	return 8 + i
   180  }
   181  
   182  func ParseUDPMessage(msg []byte) (*UDPMessage, error) {
   183  	m := &UDPMessage{}
   184  	buf := bytes.NewBuffer(msg)
   185  	if err := binary.Read(buf, binary.BigEndian, &m.SessionID); err != nil {
   186  		return nil, err
   187  	}
   188  	if err := binary.Read(buf, binary.BigEndian, &m.PacketID); err != nil {
   189  		return nil, err
   190  	}
   191  	if err := binary.Read(buf, binary.BigEndian, &m.FragID); err != nil {
   192  		return nil, err
   193  	}
   194  	if err := binary.Read(buf, binary.BigEndian, &m.FragCount); err != nil {
   195  		return nil, err
   196  	}
   197  	lAddr, err := quicvarint.Read(buf)
   198  	if err != nil {
   199  		return nil, err
   200  	}
   201  	if lAddr == 0 || lAddr > MaxMessageLength {
   202  		return nil, E.New("invalid address length")
   203  	}
   204  	bs := buf.Bytes()
   205  	m.Addr = string(bs[:lAddr])
   206  	m.Data = bs[lAddr:]
   207  	return m, nil
   208  }
   209  
   210  func ReadVString(reader io.Reader) (string, error) {
   211  	length, err := quicvarint.Read(quicvarint.NewReader(reader))
   212  	if err != nil {
   213  		return "", err
   214  	}
   215  	value, err := rw.ReadBytes(reader, int(length))
   216  	if err != nil {
   217  		return "", err
   218  	}
   219  	return string(value), nil
   220  }
   221  
   222  func WriteVString(writer io.Writer, value string) error {
   223  	err := WriteUVariant(writer, uint64(len(value)))
   224  	if err != nil {
   225  		return err
   226  	}
   227  	return rw.WriteString(writer, value)
   228  }
   229  
   230  func WriteUVariant(writer io.Writer, value uint64) error {
   231  	var b [8]byte
   232  	return common.Error(writer.Write(b[:varintPut(b[:], value)]))
   233  }
   234  
   235  // varintPut is like quicvarint.Append, but instead of appending to a slice,
   236  // it writes to a fixed-size buffer. Returns the number of bytes written.
   237  func varintPut(b []byte, i uint64) int {
   238  	if i <= maxVarInt1 {
   239  		b[0] = uint8(i)
   240  		return 1
   241  	}
   242  	if i <= maxVarInt2 {
   243  		b[0] = uint8(i>>8) | 0x40
   244  		b[1] = uint8(i)
   245  		return 2
   246  	}
   247  	if i <= maxVarInt4 {
   248  		b[0] = uint8(i>>24) | 0x80
   249  		b[1] = uint8(i >> 16)
   250  		b[2] = uint8(i >> 8)
   251  		b[3] = uint8(i)
   252  		return 4
   253  	}
   254  	if i <= maxVarInt8 {
   255  		b[0] = uint8(i>>56) | 0xc0
   256  		b[1] = uint8(i >> 48)
   257  		b[2] = uint8(i >> 40)
   258  		b[3] = uint8(i >> 32)
   259  		b[4] = uint8(i >> 24)
   260  		b[5] = uint8(i >> 16)
   261  		b[6] = uint8(i >> 8)
   262  		b[7] = uint8(i)
   263  		return 8
   264  	}
   265  	panic(fmt.Sprintf("%#x doesn't fit into 62 bits", i))
   266  }