github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/quic/frag.go (about)

     1  package quic
     2  
     3  import (
     4  	"context"
     5  	"encoding/binary"
     6  	"math"
     7  	"sync/atomic"
     8  	"time"
     9  
    10  	"github.com/Asutorufa/yuhaiin/pkg/log"
    11  	"github.com/Asutorufa/yuhaiin/pkg/utils/pool"
    12  	"github.com/Asutorufa/yuhaiin/pkg/utils/syncmap"
    13  	"github.com/quic-go/quic-go"
    14  )
    15  
    16  // https://github.com/quic-go/quic-go/blob/49e588a6a9905446e49d382d78115e6e960b1144/internal/protocol/params.go#L134
    17  // the minium depend on DataLenPresent, the minium need minus 3
    18  // see: https://github.com/quic-go/quic-go/blob/1e874896cd39adc02663be4d77ade701b333df5a/internal/wire/datagram_frame.go#L62
    19  var MaxDatagramFrameSize int64 = 1200 - 3
    20  
    21  type Frag struct {
    22  	SplitID  atomic.Uint64
    23  	mergeMap syncmap.SyncMap[uint64, *MergeFrag]
    24  }
    25  
    26  type MergeFrag struct {
    27  	Count    uint32
    28  	Total    uint32
    29  	TotalLen uint32
    30  	Data     [][]byte
    31  	time     time.Time
    32  }
    33  
    34  func (f *Frag) collect(ctx context.Context) {
    35  	timer := time.NewTimer(60 * time.Second)
    36  	defer timer.Stop()
    37  
    38  	for {
    39  		select {
    40  		case <-ctx.Done():
    41  			return
    42  
    43  		case <-timer.C:
    44  			now := time.Now()
    45  			f.mergeMap.Range(func(id uint64, v *MergeFrag) bool {
    46  				if now.Sub(v.time) > 30*time.Second {
    47  					f.mergeMap.Delete(id)
    48  				}
    49  				return true
    50  			})
    51  		}
    52  	}
    53  }
    54  
    55  func (f *Frag) Merge(buf []byte) *pool.Buffer {
    56  	fh := fragFrame(buf)
    57  
    58  	if fh.Type() == FragmentTypeSingle {
    59  		return pool.NewBuffer(fh.Payload())
    60  	}
    61  
    62  	total := fh.Total()
    63  	index := fh.Current()
    64  	id := fh.ID()
    65  
    66  	mf, ok := f.mergeMap.Load(id)
    67  
    68  	if fh.Type() != FragmentTypeSplit || total == 0 || index >= total || (ok && uint32(total) != mf.Total) {
    69  		f.mergeMap.Delete(id)
    70  		return nil
    71  	}
    72  
    73  	if !ok {
    74  		mf, _ = f.mergeMap.LoadOrStore(id, &MergeFrag{
    75  			Data:  make([][]byte, total),
    76  			Total: uint32(total),
    77  			time:  time.Now(),
    78  		})
    79  	}
    80  
    81  	current := atomic.AddUint32(&mf.Count, 1)
    82  	atomic.AddUint32(&mf.TotalLen, uint32(len(fh.Payload())))
    83  	mf.Data[index] = fh.Payload()
    84  
    85  	if current == mf.Total {
    86  		f.mergeMap.Delete(id)
    87  
    88  		buf := pool.GetBytesWriter(mf.TotalLen)
    89  		for _, v := range mf.Data {
    90  			_, _ = buf.Write(v)
    91  		}
    92  		return buf
    93  	}
    94  
    95  	return nil
    96  }
    97  
    98  func (f *Frag) Split(buf []byte, maxSize int) pool.MultipleBuffer {
    99  	headerSize := 1 + 8 + 1 + 1
   100  
   101  	if maxSize <= headerSize {
   102  		return nil
   103  	}
   104  
   105  	if len(buf) < maxSize-1 {
   106  		return pool.MultipleBuffer{NewFragFrameBytesBuffer(FragmentTypeSingle, 0, 1, 0, buf)}
   107  	}
   108  
   109  	maxSize = maxSize - headerSize
   110  
   111  	frames := len(buf) / maxSize
   112  	if len(buf)%maxSize != 0 {
   113  		frames++
   114  	}
   115  
   116  	if frames > math.MaxUint8 {
   117  		log.Error("too many frames", "frames", frames)
   118  		return nil
   119  	}
   120  
   121  	var frameArray pool.MultipleBuffer = make(pool.MultipleBuffer, 0, frames)
   122  
   123  	id := f.SplitID.Add(1)
   124  
   125  	for i := 0; i < frames; i++ {
   126  		var frame []byte
   127  		if i == frames-1 {
   128  			frame = buf[i*maxSize:]
   129  		} else {
   130  			frame = buf[i*maxSize : (i+1)*maxSize]
   131  		}
   132  
   133  		frameArray = append(frameArray, NewFragFrameBytesBuffer(FragmentTypeSplit, id, uint8(frames), uint8(i), frame))
   134  	}
   135  
   136  	return frameArray
   137  }
   138  
   139  type ConnectionPacketConn struct {
   140  	conn quic.Connection
   141  	frag *Frag
   142  }
   143  
   144  func NewConnectionPacketConn(conn quic.Connection) *ConnectionPacketConn {
   145  	frag := &Frag{}
   146  	go frag.collect(conn.Context())
   147  	return &ConnectionPacketConn{conn: conn, frag: frag}
   148  }
   149  
   150  func (c *ConnectionPacketConn) Context() context.Context {
   151  	return c.conn.Context()
   152  }
   153  
   154  func (c *ConnectionPacketConn) Receive(ctx context.Context) (uint64, *pool.Buffer, error) {
   155  _retry:
   156  	data, err := c.conn.ReceiveDatagram(ctx)
   157  	if err != nil {
   158  		return 0, nil, err
   159  	}
   160  
   161  	buf := c.frag.Merge(data)
   162  	if buf == nil {
   163  		goto _retry
   164  	}
   165  
   166  	id := binary.BigEndian.Uint64(buf.Discard(8))
   167  
   168  	return id, buf, nil
   169  }
   170  
   171  func (c *ConnectionPacketConn) Write(b []byte, id uint64) error {
   172  	buf := pool.GetBytesWriter(8 + len(b))
   173  	defer buf.Free()
   174  
   175  	buf.WriteUint64(id)
   176  	_, _ = buf.Write(b)
   177  
   178  	buffers := c.frag.Split(buf.Bytes(), int(MaxDatagramFrameSize))
   179  	defer buffers.Free()
   180  
   181  	for _, v := range buffers {
   182  		if err := c.conn.SendDatagram(v.Bytes()); err != nil {
   183  			return err
   184  		}
   185  	}
   186  
   187  	return nil
   188  }
   189  
   190  type FragType uint8
   191  
   192  const (
   193  	FragmentTypeSplit FragType = iota + 1
   194  	FragmentTypeSingle
   195  )
   196  
   197  type fragFrame []byte
   198  
   199  /*
   200  every frame max length: 1200 - 3
   201  
   202  Single Frame
   203  max payload length: 1200 - 3 - 1
   204  +-------+~~~~~~~~~~~~~~+
   205  | type  |    payload   |
   206  +-------+~~~~~~~~~~~~~~+
   207  |  1    |    variable  |
   208  +-------+~~~~~~~~~~~~~~+
   209  
   210  Split Frame
   211  max payload length: 1200 - 3 - 1 - 8 - 1 - 1
   212  +------+------------------+---------+---------+~~~~~~~~~~~~~~+
   213  | type |        id        |  total  | current |    payload   |
   214  +------+------------------+---------+---------+~~~~~~~~~~~~~~+
   215  |  1   |      8 bytes     | 1 byte  | 1 byte  |    variable  |
   216  +------+------------------+---------+---------+~~~~~~~~~~~~~~+
   217  */
   218  func NewFragFrameBytesBuffer(t FragType, id uint64, total, current uint8, payload []byte) *pool.Buffer {
   219  	var buf *pool.Buffer
   220  	if t == FragmentTypeSingle {
   221  		buf = pool.GetBytesWriter(1 + len(payload))
   222  	} else {
   223  		buf = pool.GetBytesWriter(1 + 8 + 1 + 1 + len(payload))
   224  	}
   225  	putFragFrame(buf, t, id, total, current, payload)
   226  	return buf
   227  }
   228  
   229  func putFragFrame(buf *pool.Buffer, t FragType, id uint64, total, current uint8, payload []byte) {
   230  	buf.WriteByte(byte(t))
   231  
   232  	if t == FragmentTypeSingle {
   233  		buf.Write(payload)
   234  		return
   235  	}
   236  
   237  	buf.WriteUint64(id)
   238  	buf.WriteByte(total)
   239  	buf.WriteByte(current)
   240  	buf.Write(payload)
   241  }
   242  
   243  func (f fragFrame) Type() FragType {
   244  	if len(f) < 1 {
   245  		return 0
   246  	}
   247  
   248  	return FragType(f[0])
   249  }
   250  
   251  func (f fragFrame) ID() uint64 {
   252  	if len(f) < 1+8 {
   253  		return 0
   254  	}
   255  
   256  	return binary.BigEndian.Uint64(f[1:])
   257  }
   258  
   259  func (f fragFrame) Total() uint8 {
   260  	if len(f) < 1+8+1 {
   261  		return 0
   262  	}
   263  
   264  	return f[1+8+1-1]
   265  }
   266  
   267  func (f fragFrame) Current() uint8 {
   268  	if len(f) < 1+8+1+1 {
   269  		return 0
   270  	}
   271  
   272  	return f[1+8+1+1-1]
   273  }
   274  
   275  func (f fragFrame) Payload() []byte {
   276  	if f.Type() == FragmentTypeSingle {
   277  		return f[1:]
   278  	}
   279  
   280  	if len(f) < 1+8+1+1 {
   281  		return nil
   282  	}
   283  
   284  	return f[1+8+1+1:]
   285  }