github.com/inazumav/sing-box@v0.0.0-20230926072359-ab51429a14f1/transport/hysteria2/packet.go (about)

     1  package hysteria2
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/binary"
     7  	"errors"
     8  	"github.com/sagernet/quic-go"
     9  	"io"
    10  	"math"
    11  	"net"
    12  	"os"
    13  	"sync"
    14  	"time"
    15  
    16  	"github.com/inazumav/sing-box/transport/hysteria2/internal/protocol"
    17  	"github.com/sagernet/quic-go/quicvarint"
    18  	"github.com/sagernet/sing/common"
    19  	"github.com/sagernet/sing/common/atomic"
    20  	"github.com/sagernet/sing/common/buf"
    21  	"github.com/sagernet/sing/common/cache"
    22  	M "github.com/sagernet/sing/common/metadata"
    23  )
    24  
    25  var udpMessagePool = sync.Pool{
    26  	New: func() interface{} {
    27  		return new(udpMessage)
    28  	},
    29  }
    30  
    31  func releaseMessages(messages []*udpMessage) {
    32  	for _, message := range messages {
    33  		if message != nil {
    34  			*message = udpMessage{}
    35  			udpMessagePool.Put(message)
    36  		}
    37  	}
    38  }
    39  
// udpMessage is one UDP packet, or one fragment of a packet, as exchanged over
// QUIC datagrams (see writePacket / decodeUDPMessage). Instances are recycled
// through udpMessagePool.
type udpMessage struct {
	sessionID     uint32      // copied from udpPacketConn.sessionID when writing
	packetID      uint16      // shared by all fragments of the same packet
	fragmentID    uint8       // 0-based index of this fragment
	fragmentTotal uint8       // number of fragments; values <= 1 mean "not fragmented"
	destination   string      // address text, parsed with M.ParseSocksaddr on read
	data          *buf.Buffer // payload bytes (may be a managed or an unmanaged buf.As view)
}
    48  
    49  func (m *udpMessage) release() {
    50  	*m = udpMessage{}
    51  	udpMessagePool.Put(m)
    52  }
    53  
    54  func (m *udpMessage) releaseMessage() {
    55  	m.data.Release()
    56  	m.release()
    57  }
    58  
    59  func (m *udpMessage) pack() *buf.Buffer {
    60  	buffer := buf.NewSize(m.headerSize() + m.data.Len())
    61  	common.Must(
    62  		binary.Write(buffer, binary.BigEndian, m.sessionID),
    63  		binary.Write(buffer, binary.BigEndian, m.packetID),
    64  		binary.Write(buffer, binary.BigEndian, m.fragmentID),
    65  		binary.Write(buffer, binary.BigEndian, m.fragmentTotal),
    66  		protocol.WriteVString(buffer, m.destination),
    67  		common.Error(buffer.Write(m.data.Bytes())),
    68  	)
    69  	return buffer
    70  }
    71  
    72  func (m *udpMessage) headerSize() int {
    73  	return 8 + int(quicvarint.Len(uint64(len(m.destination)))) + len(m.destination)
    74  }
    75  
    76  func fragUDPMessage(message *udpMessage, maxPacketSize int) []*udpMessage {
    77  	if message.data.Len() <= maxPacketSize {
    78  		return []*udpMessage{message}
    79  	}
    80  	var fragments []*udpMessage
    81  	originPacket := message.data.Bytes()
    82  	udpMTU := maxPacketSize - message.headerSize()
    83  	for remaining := len(originPacket); remaining > 0; remaining -= udpMTU {
    84  		fragment := udpMessagePool.Get().(*udpMessage)
    85  		*fragment = *message
    86  		if remaining > udpMTU {
    87  			fragment.data = buf.As(originPacket[:udpMTU])
    88  			originPacket = originPacket[udpMTU:]
    89  		} else {
    90  			fragment.data = buf.As(originPacket)
    91  			originPacket = nil
    92  		}
    93  		fragments = append(fragments, fragment)
    94  	}
    95  	fragmentTotal := uint16(len(fragments))
    96  	for index, fragment := range fragments {
    97  		fragment.fragmentID = uint8(index)
    98  		fragment.fragmentTotal = uint8(fragmentTotal)
    99  		/*if index > 0 {
   100  			fragment.destination = ""
   101  			// not work in hysteria
   102  		}*/
   103  	}
   104  	return fragments
   105  }
   106  
// udpPacketConn is a packet connection carried over QUIC datagrams
// (quicConn.SendMessage). Inbound messages are queued on data. Outbound
// packets are fragmented while a recently learned MTU is in effect.
type udpPacketConn struct {
	ctx        context.Context
	cancel     common.ContextCancelCauseFunc
	sessionID  uint32           // stamped on every outgoing message
	quicConn   quic.Connection
	data       chan *udpMessage // inbound queue; inputPacket drops messages when it is full
	udpMTU     int              // limit from the last quic.ErrMessageTooLarge; 0 until one is seen
	udpMTUTime time.Time        // refreshed whenever udpMTU is learned or used; fragmentation lapses 5s after
	packetId   atomic.Uint32    // outgoing packet ID counter; reset to 0 once it exceeds math.MaxUint16
	closeOnce  sync.Once        // makes Close idempotent
	defragger  *udpDefragger    // reassembles inbound fragments
	onDestroy  func()           // invoked once by Close
}
   120  
   121  func newUDPPacketConn(ctx context.Context, quicConn quic.Connection, onDestroy func()) *udpPacketConn {
   122  	ctx, cancel := common.ContextWithCancelCause(ctx)
   123  	return &udpPacketConn{
   124  		ctx:       ctx,
   125  		cancel:    cancel,
   126  		quicConn:  quicConn,
   127  		data:      make(chan *udpMessage, 64),
   128  		defragger: newUDPDefragger(),
   129  		onDestroy: onDestroy,
   130  	}
   131  }
   132  
   133  func (c *udpPacketConn) ReadPacketThreadSafe() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
   134  	select {
   135  	case p := <-c.data:
   136  		buffer = p.data
   137  		destination = M.ParseSocksaddr(p.destination)
   138  		p.release()
   139  		return
   140  	case <-c.ctx.Done():
   141  		return nil, M.Socksaddr{}, io.ErrClosedPipe
   142  	}
   143  }
   144  
   145  func (c *udpPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
   146  	select {
   147  	case p := <-c.data:
   148  		_, err = buffer.ReadOnceFrom(p.data)
   149  		destination = M.ParseSocksaddr(p.destination)
   150  		p.releaseMessage()
   151  		return
   152  	case <-c.ctx.Done():
   153  		return M.Socksaddr{}, io.ErrClosedPipe
   154  	}
   155  }
   156  
   157  func (c *udpPacketConn) WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error) {
   158  	select {
   159  	case p := <-c.data:
   160  		_, err = newBuffer().ReadOnceFrom(p.data)
   161  		destination = M.ParseSocksaddr(p.destination)
   162  		p.releaseMessage()
   163  		return
   164  	case <-c.ctx.Done():
   165  		return M.Socksaddr{}, io.ErrClosedPipe
   166  	}
   167  }
   168  
   169  func (c *udpPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
   170  	select {
   171  	case pkt := <-c.data:
   172  		n = copy(p, pkt.data.Bytes())
   173  		destination := M.ParseSocksaddr(pkt.destination)
   174  		if destination.IsFqdn() {
   175  			addr = destination
   176  		} else {
   177  			addr = destination.UDPAddr()
   178  		}
   179  		pkt.releaseMessage()
   180  		return n, addr, nil
   181  	case <-c.ctx.Done():
   182  		return 0, nil, io.ErrClosedPipe
   183  	}
   184  }
   185  
   186  func (c *udpPacketConn) needFragment() bool {
   187  	nowTime := time.Now()
   188  	if c.udpMTU > 0 && nowTime.Sub(c.udpMTUTime) < 5*time.Second {
   189  		c.udpMTUTime = nowTime
   190  		return true
   191  	}
   192  	return false
   193  }
   194  
   195  func (c *udpPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
   196  	defer buffer.Release()
   197  	select {
   198  	case <-c.ctx.Done():
   199  		return net.ErrClosed
   200  	default:
   201  	}
   202  	if buffer.Len() > 0xffff {
   203  		return quic.ErrMessageTooLarge(0xffff)
   204  	}
   205  	packetId := c.packetId.Add(1)
   206  	if packetId > math.MaxUint16 {
   207  		c.packetId.Store(0)
   208  		packetId = 0
   209  	}
   210  	message := udpMessagePool.Get().(*udpMessage)
   211  	*message = udpMessage{
   212  		sessionID:     c.sessionID,
   213  		packetID:      uint16(packetId),
   214  		fragmentTotal: 1,
   215  		destination:   destination.String(),
   216  		data:          buffer,
   217  	}
   218  	defer message.releaseMessage()
   219  	var err error
   220  	if c.needFragment() && buffer.Len() > c.udpMTU {
   221  		err = c.writePackets(fragUDPMessage(message, c.udpMTU))
   222  	} else {
   223  		err = c.writePacket(message)
   224  	}
   225  	if err == nil {
   226  		return nil
   227  	}
   228  	var tooLargeErr quic.ErrMessageTooLarge
   229  	if !errors.As(err, &tooLargeErr) {
   230  		return err
   231  	}
   232  	c.udpMTU = int(tooLargeErr)
   233  	c.udpMTUTime = time.Now()
   234  	return c.writePackets(fragUDPMessage(message, c.udpMTU))
   235  }
   236  
   237  func (c *udpPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
   238  	select {
   239  	case <-c.ctx.Done():
   240  		return 0, net.ErrClosed
   241  	default:
   242  	}
   243  	if len(p) > 0xffff {
   244  		return 0, quic.ErrMessageTooLarge(0xffff)
   245  	}
   246  	packetId := c.packetId.Add(1)
   247  	if packetId > math.MaxUint16 {
   248  		c.packetId.Store(0)
   249  		packetId = 0
   250  	}
   251  	message := udpMessagePool.Get().(*udpMessage)
   252  	*message = udpMessage{
   253  		sessionID:     c.sessionID,
   254  		packetID:      uint16(packetId),
   255  		fragmentTotal: 1,
   256  		destination:   addr.String(),
   257  		data:          buf.As(p),
   258  	}
   259  	if c.needFragment() && len(p) > c.udpMTU {
   260  		err = c.writePackets(fragUDPMessage(message, c.udpMTU))
   261  		if err == nil {
   262  			return len(p), nil
   263  		}
   264  	} else {
   265  		err = c.writePacket(message)
   266  	}
   267  	if err == nil {
   268  		return len(p), nil
   269  	}
   270  	var tooLargeErr quic.ErrMessageTooLarge
   271  	if !errors.As(err, &tooLargeErr) {
   272  		return
   273  	}
   274  	c.udpMTU = int(tooLargeErr)
   275  	c.udpMTUTime = time.Now()
   276  	err = c.writePackets(fragUDPMessage(message, c.udpMTU))
   277  	if err == nil {
   278  		return len(p), nil
   279  	}
   280  	return
   281  }
   282  
   283  func (c *udpPacketConn) inputPacket(message *udpMessage) {
   284  	if message.fragmentTotal <= 1 {
   285  		select {
   286  		case c.data <- message:
   287  		default:
   288  		}
   289  	} else {
   290  		newMessage := c.defragger.feed(message)
   291  		if newMessage != nil {
   292  			select {
   293  			case c.data <- newMessage:
   294  			default:
   295  			}
   296  		}
   297  	}
   298  }
   299  
   300  func (c *udpPacketConn) writePackets(messages []*udpMessage) error {
   301  	defer releaseMessages(messages)
   302  	for _, message := range messages {
   303  		err := c.writePacket(message)
   304  		if err != nil {
   305  			return err
   306  		}
   307  	}
   308  	return nil
   309  }
   310  
   311  func (c *udpPacketConn) writePacket(message *udpMessage) error {
   312  	buffer := message.pack()
   313  	defer buffer.Release()
   314  	return c.quicConn.SendMessage(buffer.Bytes())
   315  }
   316  
   317  func (c *udpPacketConn) Close() error {
   318  	c.closeOnce.Do(func() {
   319  		c.closeWithError(os.ErrClosed)
   320  		c.onDestroy()
   321  	})
   322  	return nil
   323  }
   324  
   325  func (c *udpPacketConn) closeWithError(err error) {
   326  	c.cancel(err)
   327  }
   328  
   329  func (c *udpPacketConn) LocalAddr() net.Addr {
   330  	return c.quicConn.LocalAddr()
   331  }
   332  
   333  func (c *udpPacketConn) SetDeadline(t time.Time) error {
   334  	return os.ErrInvalid
   335  }
   336  
   337  func (c *udpPacketConn) SetReadDeadline(t time.Time) error {
   338  	return os.ErrInvalid
   339  }
   340  
   341  func (c *udpPacketConn) SetWriteDeadline(t time.Time) error {
   342  	return os.ErrInvalid
   343  }
   344  
// udpDefragger reassembles fragmented inbound messages. Partial state is kept
// per packet ID in an LRU cache.
type udpDefragger struct {
	packetMap *cache.LruCache[uint16, *packetItem] // packetID -> fragments received so far
}
   348  
   349  func newUDPDefragger() *udpDefragger {
   350  	return &udpDefragger{
   351  		packetMap: cache.New(
   352  			cache.WithAge[uint16, *packetItem](10),
   353  			cache.WithUpdateAgeOnGet[uint16, *packetItem](),
   354  			cache.WithEvict[uint16, *packetItem](func(key uint16, value *packetItem) {
   355  				releaseMessages(value.messages)
   356  			}),
   357  		),
   358  	}
   359  }
   360  
// packetItem holds the fragments collected so far for one packet ID.
type packetItem struct {
	access   sync.Mutex    // guards messages and count inside feed
	messages []*udpMessage // indexed by fragmentID; nil slots are still missing
	count    uint8         // fragments received for the current packet; complete when it equals len(messages)
}
   366  
   367  func (d *udpDefragger) feed(m *udpMessage) *udpMessage {
   368  	if m.fragmentTotal <= 1 {
   369  		return m
   370  	}
   371  	if m.fragmentID >= m.fragmentTotal {
   372  		return nil
   373  	}
   374  	item, _ := d.packetMap.LoadOrStore(m.packetID, newPacketItem)
   375  	item.access.Lock()
   376  	defer item.access.Unlock()
   377  	if int(m.fragmentTotal) != len(item.messages) {
   378  		releaseMessages(item.messages)
   379  		item.messages = make([]*udpMessage, m.fragmentTotal)
   380  		item.count = 1
   381  		item.messages[m.fragmentID] = m
   382  		return nil
   383  	}
   384  	if item.messages[m.fragmentID] != nil {
   385  		return nil
   386  	}
   387  	item.messages[m.fragmentID] = m
   388  	item.count++
   389  	if int(item.count) != len(item.messages) {
   390  		return nil
   391  	}
   392  	newMessage := udpMessagePool.Get().(*udpMessage)
   393  	*newMessage = *item.messages[0]
   394  	var finalLength int
   395  	for _, message := range item.messages {
   396  		finalLength += message.data.Len()
   397  	}
   398  	if finalLength > 0 {
   399  		newMessage.data = buf.NewSize(finalLength)
   400  		for _, message := range item.messages {
   401  			newMessage.data.Write(message.data.Bytes())
   402  			message.releaseMessage()
   403  		}
   404  		item.messages = nil
   405  		return newMessage
   406  	}
   407  	return nil
   408  }
   409  
   410  func newPacketItem() *packetItem {
   411  	return new(packetItem)
   412  }
   413  
   414  func decodeUDPMessage(message *udpMessage, data []byte) error {
   415  	reader := bytes.NewReader(data)
   416  	err := binary.Read(reader, binary.BigEndian, &message.sessionID)
   417  	if err != nil {
   418  		return err
   419  	}
   420  	err = binary.Read(reader, binary.BigEndian, &message.packetID)
   421  	if err != nil {
   422  		return err
   423  	}
   424  	err = binary.Read(reader, binary.BigEndian, &message.fragmentID)
   425  	if err != nil {
   426  		return err
   427  	}
   428  	err = binary.Read(reader, binary.BigEndian, &message.fragmentTotal)
   429  	if err != nil {
   430  		return err
   431  	}
   432  	message.destination, err = protocol.ReadVString(reader)
   433  	if err != nil {
   434  		return err
   435  	}
   436  	message.data = buf.As(data[len(data)-reader.Len():])
   437  	return nil
   438  }