github.com/inazumav/sing-box@v0.0.0-20230926072359-ab51429a14f1/transport/hysteria2/packet.go (about) 1 package hysteria2 2 3 import ( 4 "bytes" 5 "context" 6 "encoding/binary" 7 "errors" 8 "github.com/sagernet/quic-go" 9 "io" 10 "math" 11 "net" 12 "os" 13 "sync" 14 "time" 15 16 "github.com/inazumav/sing-box/transport/hysteria2/internal/protocol" 17 "github.com/sagernet/quic-go/quicvarint" 18 "github.com/sagernet/sing/common" 19 "github.com/sagernet/sing/common/atomic" 20 "github.com/sagernet/sing/common/buf" 21 "github.com/sagernet/sing/common/cache" 22 M "github.com/sagernet/sing/common/metadata" 23 ) 24 25 var udpMessagePool = sync.Pool{ 26 New: func() interface{} { 27 return new(udpMessage) 28 }, 29 } 30 31 func releaseMessages(messages []*udpMessage) { 32 for _, message := range messages { 33 if message != nil { 34 *message = udpMessage{} 35 udpMessagePool.Put(message) 36 } 37 } 38 } 39 40 type udpMessage struct { 41 sessionID uint32 42 packetID uint16 43 fragmentID uint8 44 fragmentTotal uint8 45 destination string 46 data *buf.Buffer 47 } 48 49 func (m *udpMessage) release() { 50 *m = udpMessage{} 51 udpMessagePool.Put(m) 52 } 53 54 func (m *udpMessage) releaseMessage() { 55 m.data.Release() 56 m.release() 57 } 58 59 func (m *udpMessage) pack() *buf.Buffer { 60 buffer := buf.NewSize(m.headerSize() + m.data.Len()) 61 common.Must( 62 binary.Write(buffer, binary.BigEndian, m.sessionID), 63 binary.Write(buffer, binary.BigEndian, m.packetID), 64 binary.Write(buffer, binary.BigEndian, m.fragmentID), 65 binary.Write(buffer, binary.BigEndian, m.fragmentTotal), 66 protocol.WriteVString(buffer, m.destination), 67 common.Error(buffer.Write(m.data.Bytes())), 68 ) 69 return buffer 70 } 71 72 func (m *udpMessage) headerSize() int { 73 return 8 + int(quicvarint.Len(uint64(len(m.destination)))) + len(m.destination) 74 } 75 76 func fragUDPMessage(message *udpMessage, maxPacketSize int) []*udpMessage { 77 if message.data.Len() <= maxPacketSize { 78 return []*udpMessage{message} 79 } 80 var fragments []*udpMessage 81 originPacket := message.data.Bytes() 82 udpMTU := maxPacketSize - message.headerSize() 83 for remaining := len(originPacket); remaining > 0; remaining -= udpMTU { 84 fragment := udpMessagePool.Get().(*udpMessage) 85 *fragment = *message 86 if remaining > udpMTU { 87 fragment.data = buf.As(originPacket[:udpMTU]) 88 originPacket = originPacket[udpMTU:] 89 } else { 90 fragment.data = buf.As(originPacket) 91 originPacket = nil 92 } 93 fragments = append(fragments, fragment) 94 } 95 fragmentTotal := uint16(len(fragments)) 96 for index, fragment := range fragments { 97 fragment.fragmentID = uint8(index) 98 fragment.fragmentTotal = uint8(fragmentTotal) 99 /*if index > 0 { 100 fragment.destination = "" 101 // not work in hysteria 102 }*/ 103 } 104 return fragments 105 } 106 107 type udpPacketConn struct { 108 ctx context.Context 109 cancel common.ContextCancelCauseFunc 110 sessionID uint32 111 quicConn quic.Connection 112 data chan *udpMessage 113 udpMTU int 114 udpMTUTime time.Time 115 packetId atomic.Uint32 116 closeOnce sync.Once 117 defragger *udpDefragger 118 onDestroy func() 119 } 120 121 func newUDPPacketConn(ctx context.Context, quicConn quic.Connection, onDestroy func()) *udpPacketConn { 122 ctx, cancel := common.ContextWithCancelCause(ctx) 123 return &udpPacketConn{ 124 ctx: ctx, 125 cancel: cancel, 126 quicConn: quicConn, 127 data: make(chan *udpMessage, 64), 128 defragger: newUDPDefragger(), 129 onDestroy: onDestroy, 130 } 131 } 132 133 func (c *udpPacketConn) ReadPacketThreadSafe() (buffer *buf.Buffer, destination M.Socksaddr, err error) { 134 select { 135 case p := <-c.data: 136 buffer = p.data 137 destination = M.ParseSocksaddr(p.destination) 138 p.release() 139 return 140 case <-c.ctx.Done(): 141 return nil, M.Socksaddr{}, io.ErrClosedPipe 142 } 143 } 144 145 func (c *udpPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { 146 select { 147 case p := <-c.data: 148 _, err = buffer.ReadOnceFrom(p.data) 149 destination = M.ParseSocksaddr(p.destination) 150 p.releaseMessage() 151 return 152 case <-c.ctx.Done(): 153 return M.Socksaddr{}, io.ErrClosedPipe 154 } 155 } 156 157 func (c *udpPacketConn) WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error) { 158 select { 159 case p := <-c.data: 160 _, err = newBuffer().ReadOnceFrom(p.data) 161 destination = M.ParseSocksaddr(p.destination) 162 p.releaseMessage() 163 return 164 case <-c.ctx.Done(): 165 return M.Socksaddr{}, io.ErrClosedPipe 166 } 167 } 168 169 func (c *udpPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { 170 select { 171 case pkt := <-c.data: 172 n = copy(p, pkt.data.Bytes()) 173 destination := M.ParseSocksaddr(pkt.destination) 174 if destination.IsFqdn() { 175 addr = destination 176 } else { 177 addr = destination.UDPAddr() 178 } 179 pkt.releaseMessage() 180 return n, addr, nil 181 case <-c.ctx.Done(): 182 return 0, nil, io.ErrClosedPipe 183 } 184 } 185 186 func (c *udpPacketConn) needFragment() bool { 187 nowTime := time.Now() 188 if c.udpMTU > 0 && nowTime.Sub(c.udpMTUTime) < 5*time.Second { 189 c.udpMTUTime = nowTime 190 return true 191 } 192 return false 193 } 194 195 func (c *udpPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { 196 defer buffer.Release() 197 select { 198 case <-c.ctx.Done(): 199 return net.ErrClosed 200 default: 201 } 202 if buffer.Len() > 0xffff { 203 return quic.ErrMessageTooLarge(0xffff) 204 } 205 packetId := c.packetId.Add(1) 206 if packetId > math.MaxUint16 { 207 c.packetId.Store(0) 208 packetId = 0 209 } 210 message := udpMessagePool.Get().(*udpMessage) 211 *message = udpMessage{ 212 sessionID: c.sessionID, 213 packetID: uint16(packetId), 214 fragmentTotal: 1, 215 destination: destination.String(), 216 data: buffer, 217 } 218 defer message.releaseMessage() 219 var err error 220 if c.needFragment() && buffer.Len() > c.udpMTU { 221 err = c.writePackets(fragUDPMessage(message, c.udpMTU)) 222 } else { 223 err = c.writePacket(message) 224 } 225 if err == nil { 226 return nil 227 } 228 var tooLargeErr quic.ErrMessageTooLarge 229 if !errors.As(err, &tooLargeErr) { 230 return err 231 } 232 c.udpMTU = int(tooLargeErr) 233 c.udpMTUTime = time.Now() 234 return c.writePackets(fragUDPMessage(message, c.udpMTU)) 235 } 236 237 func (c *udpPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { 238 select { 239 case <-c.ctx.Done(): 240 return 0, net.ErrClosed 241 default: 242 } 243 if len(p) > 0xffff { 244 return 0, quic.ErrMessageTooLarge(0xffff) 245 } 246 packetId := c.packetId.Add(1) 247 if packetId > math.MaxUint16 { 248 c.packetId.Store(0) 249 packetId = 0 250 } 251 message := udpMessagePool.Get().(*udpMessage) 252 *message = udpMessage{ 253 sessionID: c.sessionID, 254 packetID: uint16(packetId), 255 fragmentTotal: 1, 256 destination: addr.String(), 257 data: buf.As(p), 258 } 259 if c.needFragment() && len(p) > c.udpMTU { 260 err = c.writePackets(fragUDPMessage(message, c.udpMTU)) 261 if err == nil { 262 return len(p), nil 263 } 264 } else { 265 err = c.writePacket(message) 266 } 267 if err == nil { 268 return len(p), nil 269 } 270 var tooLargeErr quic.ErrMessageTooLarge 271 if !errors.As(err, &tooLargeErr) { 272 return 273 } 274 c.udpMTU = int(tooLargeErr) 275 c.udpMTUTime = time.Now() 276 err = c.writePackets(fragUDPMessage(message, c.udpMTU)) 277 if err == nil { 278 return len(p), nil 279 } 280 return 281 } 282 283 func (c *udpPacketConn) inputPacket(message *udpMessage) { 284 if message.fragmentTotal <= 1 { 285 select { 286 case c.data <- message: 287 default: 288 } 289 } else { 290 newMessage := c.defragger.feed(message) 291 if newMessage != nil { 292 select { 293 case c.data <- newMessage: 294 default: 295 } 296 } 297 } 298 } 299 300 func (c *udpPacketConn) writePackets(messages []*udpMessage) error { 301 defer releaseMessages(messages) 302 for _, message := range messages { 303 err := c.writePacket(message) 304 if err != nil { 305 return err 306 } 307 } 308 return nil 309 } 310 311 func (c *udpPacketConn) writePacket(message *udpMessage) error { 312 buffer := message.pack() 313 defer buffer.Release() 314 return c.quicConn.SendMessage(buffer.Bytes()) 315 } 316 317 func (c *udpPacketConn) Close() error { 318 c.closeOnce.Do(func() { 319 c.closeWithError(os.ErrClosed) 320 c.onDestroy() 321 }) 322 return nil 323 } 324 325 func (c *udpPacketConn) closeWithError(err error) { 326 c.cancel(err) 327 } 328 329 func (c *udpPacketConn) LocalAddr() net.Addr { 330 return c.quicConn.LocalAddr() 331 } 332 333 func (c *udpPacketConn) SetDeadline(t time.Time) error { 334 return os.ErrInvalid 335 } 336 337 func (c *udpPacketConn) SetReadDeadline(t time.Time) error { 338 return os.ErrInvalid 339 } 340 341 func (c *udpPacketConn) SetWriteDeadline(t time.Time) error { 342 return os.ErrInvalid 343 } 344 345 type udpDefragger struct { 346 packetMap *cache.LruCache[uint16, *packetItem] 347 } 348 349 func newUDPDefragger() *udpDefragger { 350 return &udpDefragger{ 351 packetMap: cache.New( 352 cache.WithAge[uint16, *packetItem](10), 353 cache.WithUpdateAgeOnGet[uint16, *packetItem](), 354 cache.WithEvict[uint16, *packetItem](func(key uint16, value *packetItem) { 355 releaseMessages(value.messages) 356 }), 357 ), 358 } 359 } 360 361 type packetItem struct { 362 access sync.Mutex 363 messages []*udpMessage 364 count uint8 365 } 366 367 func (d *udpDefragger) feed(m *udpMessage) *udpMessage { 368 if m.fragmentTotal <= 1 { 369 return m 370 } 371 if m.fragmentID >= m.fragmentTotal { 372 return nil 373 } 374 item, _ := d.packetMap.LoadOrStore(m.packetID, newPacketItem) 375 item.access.Lock() 376 defer item.access.Unlock() 377 if int(m.fragmentTotal) != len(item.messages) { 378 releaseMessages(item.messages) 379 item.messages = make([]*udpMessage, m.fragmentTotal) 380 item.count = 1 381 item.messages[m.fragmentID] = m 382 return nil 383 } 384 if item.messages[m.fragmentID] != nil { 385 return nil 386 } 387 item.messages[m.fragmentID] = m 388 item.count++ 389 if int(item.count) != len(item.messages) { 390 return nil 391 } 392 newMessage := udpMessagePool.Get().(*udpMessage) 393 *newMessage = *item.messages[0] 394 var finalLength int 395 for _, message := range item.messages { 396 finalLength += message.data.Len() 397 } 398 if finalLength > 0 { 399 newMessage.data = buf.NewSize(finalLength) 400 for _, message := range item.messages { 401 newMessage.data.Write(message.data.Bytes()) 402 message.releaseMessage() 403 } 404 item.messages = nil 405 return newMessage 406 } 407 return nil 408 } 409 410 func newPacketItem() *packetItem { 411 return new(packetItem) 412 } 413 414 func decodeUDPMessage(message *udpMessage, data []byte) error { 415 reader := bytes.NewReader(data) 416 err := binary.Read(reader, binary.BigEndian, &message.sessionID) 417 if err != nil { 418 return err 419 } 420 err = binary.Read(reader, binary.BigEndian, &message.packetID) 421 if err != nil { 422 return err 423 } 424 err = binary.Read(reader, binary.BigEndian, &message.fragmentID) 425 if err != nil { 426 return err 427 } 428 err = binary.Read(reader, binary.BigEndian, &message.fragmentTotal) 429 if err != nil { 430 return err 431 } 432 message.destination, err = protocol.ReadVString(reader) 433 if err != nil { 434 return err 435 } 436 message.data = buf.As(data[len(data)-reader.Len():]) 437 return nil 438 }