github.com/inazumav/sing-box@v0.0.0-20230926072359-ab51429a14f1/transport/hysteria2/internal/protocol/proxy.go (about) 1 package protocol 2 3 import ( 4 "bytes" 5 "encoding/binary" 6 "fmt" 7 "io" 8 9 "github.com/sagernet/quic-go/quicvarint" 10 "github.com/sagernet/sing/common" 11 "github.com/sagernet/sing/common/buf" 12 E "github.com/sagernet/sing/common/exceptions" 13 "github.com/sagernet/sing/common/rw" 14 ) 15 16 const ( 17 FrameTypeTCPRequest = 0x401 18 19 // Max length values are for preventing DoS attacks 20 21 MaxAddressLength = 2048 22 MaxMessageLength = 2048 23 MaxPaddingLength = 4096 24 25 MaxUDPSize = 4096 26 27 maxVarInt1 = 63 28 maxVarInt2 = 16383 29 maxVarInt4 = 1073741823 30 maxVarInt8 = 4611686018427387903 31 ) 32 33 // TCPRequest format: 34 // 0x401 (QUIC varint) 35 // Address length (QUIC varint) 36 // Address (bytes) 37 // Padding length (QUIC varint) 38 // Padding (bytes) 39 40 func ReadTCPRequest(r io.Reader) (string, error) { 41 bReader := quicvarint.NewReader(r) 42 addrLen, err := quicvarint.Read(bReader) 43 if err != nil { 44 return "", err 45 } 46 if addrLen == 0 || addrLen > MaxAddressLength { 47 return "", E.New("invalid address length") 48 } 49 addrBuf := make([]byte, addrLen) 50 _, err = io.ReadFull(r, addrBuf) 51 if err != nil { 52 return "", err 53 } 54 paddingLen, err := quicvarint.Read(bReader) 55 if err != nil { 56 return "", err 57 } 58 if paddingLen > MaxPaddingLength { 59 return "", E.New("invalid padding length") 60 } 61 if paddingLen > 0 { 62 _, err = io.CopyN(io.Discard, r, int64(paddingLen)) 63 if err != nil { 64 return "", err 65 } 66 } 67 return string(addrBuf), nil 68 } 69 70 func WriteTCPRequest(addr string, payload []byte) *buf.Buffer { 71 padding := tcpRequestPadding.String() 72 paddingLen := len(padding) 73 addrLen := len(addr) 74 sz := int(quicvarint.Len(FrameTypeTCPRequest)) + 75 int(quicvarint.Len(uint64(addrLen))) + addrLen + 76 int(quicvarint.Len(uint64(paddingLen))) + paddingLen 77 buffer := buf.NewSize(sz + len(payload)) 78 bufferContent := buffer.Extend(sz) 79 i := varintPut(bufferContent, FrameTypeTCPRequest) 80 i += varintPut(bufferContent[i:], uint64(addrLen)) 81 i += copy(bufferContent[i:], addr) 82 i += varintPut(bufferContent[i:], uint64(paddingLen)) 83 copy(bufferContent[i:], padding) 84 buffer.Write(payload) 85 return buffer 86 } 87 88 // TCPResponse format: 89 // Status (byte, 0=ok, 1=error) 90 // Message length (QUIC varint) 91 // Message (bytes) 92 // Padding length (QUIC varint) 93 // Padding (bytes) 94 95 func ReadTCPResponse(r io.Reader) (bool, string, error) { 96 var status [1]byte 97 if _, err := io.ReadFull(r, status[:]); err != nil { 98 return false, "", err 99 } 100 bReader := quicvarint.NewReader(r) 101 msg, err := ReadVString(bReader) 102 if err != nil { 103 return false, "", err 104 } 105 paddingLen, err := quicvarint.Read(bReader) 106 if err != nil { 107 return false, "", err 108 } 109 if paddingLen > MaxPaddingLength { 110 return false, "", E.New("invalid padding length") 111 } 112 if paddingLen > 0 { 113 _, err = io.CopyN(io.Discard, r, int64(paddingLen)) 114 if err != nil { 115 return false, "", err 116 } 117 } 118 return status[0] == 0, msg, nil 119 } 120 121 func WriteTCPResponse(ok bool, msg string, payload []byte) *buf.Buffer { 122 padding := tcpResponsePadding.String() 123 paddingLen := len(padding) 124 msgLen := len(msg) 125 sz := 1 + int(quicvarint.Len(uint64(msgLen))) + msgLen + 126 int(quicvarint.Len(uint64(paddingLen))) + paddingLen 127 buffer := buf.NewSize(sz + len(payload)) 128 if ok { 129 buffer.WriteByte(0) 130 } else { 131 buffer.WriteByte(1) 132 } 133 WriteVString(buffer, msg) 134 WriteUVariant(buffer, uint64(paddingLen)) 135 buffer.Extend(paddingLen) 136 buffer.Write(payload) 137 return buffer 138 } 139 140 // UDPMessage format: 141 // Session ID (uint32 BE) 142 // Packet ID (uint16 BE) 143 // Fragment ID (uint8) 144 // Fragment count (uint8) 145 // Address length (QUIC varint) 146 // Address (bytes) 147 // Data... 148 149 type UDPMessage struct { 150 SessionID uint32 // 4 151 PacketID uint16 // 2 152 FragID uint8 // 1 153 FragCount uint8 // 1 154 Addr string // varint + bytes 155 Data []byte 156 } 157 158 func (m *UDPMessage) HeaderSize() int { 159 lAddr := len(m.Addr) 160 return 4 + 2 + 1 + 1 + int(quicvarint.Len(uint64(lAddr))) + lAddr 161 } 162 163 func (m *UDPMessage) Size() int { 164 return m.HeaderSize() + len(m.Data) 165 } 166 167 func (m *UDPMessage) Serialize(buf []byte) int { 168 // Make sure the buffer is big enough 169 if len(buf) < m.Size() { 170 return -1 171 } 172 binary.BigEndian.PutUint32(buf, m.SessionID) 173 binary.BigEndian.PutUint16(buf[4:], m.PacketID) 174 buf[6] = m.FragID 175 buf[7] = m.FragCount 176 i := varintPut(buf[8:], uint64(len(m.Addr))) 177 i += copy(buf[8+i:], m.Addr) 178 i += copy(buf[8+i:], m.Data) 179 return 8 + i 180 } 181 182 func ParseUDPMessage(msg []byte) (*UDPMessage, error) { 183 m := &UDPMessage{} 184 buf := bytes.NewBuffer(msg) 185 if err := binary.Read(buf, binary.BigEndian, &m.SessionID); err != nil { 186 return nil, err 187 } 188 if err := binary.Read(buf, binary.BigEndian, &m.PacketID); err != nil { 189 return nil, err 190 } 191 if err := binary.Read(buf, binary.BigEndian, &m.FragID); err != nil { 192 return nil, err 193 } 194 if err := binary.Read(buf, binary.BigEndian, &m.FragCount); err != nil { 195 return nil, err 196 } 197 lAddr, err := quicvarint.Read(buf) 198 if err != nil { 199 return nil, err 200 } 201 if lAddr == 0 || lAddr > MaxMessageLength { 202 return nil, E.New("invalid address length") 203 } 204 bs := buf.Bytes() 205 m.Addr = string(bs[:lAddr]) 206 m.Data = bs[lAddr:] 207 return m, nil 208 } 209 210 func ReadVString(reader io.Reader) (string, error) { 211 length, err := quicvarint.Read(quicvarint.NewReader(reader)) 212 if err != nil { 213 return "", err 214 } 215 value, err := rw.ReadBytes(reader, int(length)) 216 if err != nil { 217 return "", err 218 } 219 return string(value), nil 220 } 221 222 func WriteVString(writer io.Writer, value string) error { 223 err := WriteUVariant(writer, uint64(len(value))) 224 if err != nil { 225 return err 226 } 227 return rw.WriteString(writer, value) 228 } 229 230 func WriteUVariant(writer io.Writer, value uint64) error { 231 var b [8]byte 232 return common.Error(writer.Write(b[:varintPut(b[:], value)])) 233 } 234 235 // varintPut is like quicvarint.Append, but instead of appending to a slice, 236 // it writes to a fixed-size buffer. Returns the number of bytes written. 237 func varintPut(b []byte, i uint64) int { 238 if i <= maxVarInt1 { 239 b[0] = uint8(i) 240 return 1 241 } 242 if i <= maxVarInt2 { 243 b[0] = uint8(i>>8) | 0x40 244 b[1] = uint8(i) 245 return 2 246 } 247 if i <= maxVarInt4 { 248 b[0] = uint8(i>>24) | 0x80 249 b[1] = uint8(i >> 16) 250 b[2] = uint8(i >> 8) 251 b[3] = uint8(i) 252 return 4 253 } 254 if i <= maxVarInt8 { 255 b[0] = uint8(i>>56) | 0xc0 256 b[1] = uint8(i >> 48) 257 b[2] = uint8(i >> 40) 258 b[3] = uint8(i >> 32) 259 b[4] = uint8(i >> 24) 260 b[5] = uint8(i >> 16) 261 b[6] = uint8(i >> 8) 262 b[7] = uint8(i) 263 return 8 264 } 265 panic(fmt.Sprintf("%#x doesn't fit into 62 bits", i)) 266 }