github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/quic/frag.go (about) 1 package quic 2 3 import ( 4 "context" 5 "encoding/binary" 6 "math" 7 "sync/atomic" 8 "time" 9 10 "github.com/Asutorufa/yuhaiin/pkg/log" 11 "github.com/Asutorufa/yuhaiin/pkg/utils/pool" 12 "github.com/Asutorufa/yuhaiin/pkg/utils/syncmap" 13 "github.com/quic-go/quic-go" 14 ) 15 16 // https://github.com/quic-go/quic-go/blob/49e588a6a9905446e49d382d78115e6e960b1144/internal/protocol/params.go#L134 17 // the minium depend on DataLenPresent, the minium need minus 3 18 // see: https://github.com/quic-go/quic-go/blob/1e874896cd39adc02663be4d77ade701b333df5a/internal/wire/datagram_frame.go#L62 19 var MaxDatagramFrameSize int64 = 1200 - 3 20 21 type Frag struct { 22 SplitID atomic.Uint64 23 mergeMap syncmap.SyncMap[uint64, *MergeFrag] 24 } 25 26 type MergeFrag struct { 27 Count uint32 28 Total uint32 29 TotalLen uint32 30 Data [][]byte 31 time time.Time 32 } 33 34 func (f *Frag) collect(ctx context.Context) { 35 timer := time.NewTimer(60 * time.Second) 36 defer timer.Stop() 37 38 for { 39 select { 40 case <-ctx.Done(): 41 return 42 43 case <-timer.C: 44 now := time.Now() 45 f.mergeMap.Range(func(id uint64, v *MergeFrag) bool { 46 if now.Sub(v.time) > 30*time.Second { 47 f.mergeMap.Delete(id) 48 } 49 return true 50 }) 51 } 52 } 53 } 54 55 func (f *Frag) Merge(buf []byte) *pool.Buffer { 56 fh := fragFrame(buf) 57 58 if fh.Type() == FragmentTypeSingle { 59 return pool.NewBuffer(fh.Payload()) 60 } 61 62 total := fh.Total() 63 index := fh.Current() 64 id := fh.ID() 65 66 mf, ok := f.mergeMap.Load(id) 67 68 if fh.Type() != FragmentTypeSplit || total == 0 || index >= total || (ok && uint32(total) != mf.Total) { 69 f.mergeMap.Delete(id) 70 return nil 71 } 72 73 if !ok { 74 mf, _ = f.mergeMap.LoadOrStore(id, &MergeFrag{ 75 Data: make([][]byte, total), 76 Total: uint32(total), 77 time: time.Now(), 78 }) 79 } 80 81 current := atomic.AddUint32(&mf.Count, 1) 82 atomic.AddUint32(&mf.TotalLen, uint32(len(fh.Payload()))) 83 mf.Data[index] = fh.Payload() 84 85 if current == mf.Total { 86 f.mergeMap.Delete(id) 87 88 buf := pool.GetBytesWriter(mf.TotalLen) 89 for _, v := range mf.Data { 90 _, _ = buf.Write(v) 91 } 92 return buf 93 } 94 95 return nil 96 } 97 98 func (f *Frag) Split(buf []byte, maxSize int) pool.MultipleBuffer { 99 headerSize := 1 + 8 + 1 + 1 100 101 if maxSize <= headerSize { 102 return nil 103 } 104 105 if len(buf) < maxSize-1 { 106 return pool.MultipleBuffer{NewFragFrameBytesBuffer(FragmentTypeSingle, 0, 1, 0, buf)} 107 } 108 109 maxSize = maxSize - headerSize 110 111 frames := len(buf) / maxSize 112 if len(buf)%maxSize != 0 { 113 frames++ 114 } 115 116 if frames > math.MaxUint8 { 117 log.Error("too many frames", "frames", frames) 118 return nil 119 } 120 121 var frameArray pool.MultipleBuffer = make(pool.MultipleBuffer, 0, frames) 122 123 id := f.SplitID.Add(1) 124 125 for i := 0; i < frames; i++ { 126 var frame []byte 127 if i == frames-1 { 128 frame = buf[i*maxSize:] 129 } else { 130 frame = buf[i*maxSize : (i+1)*maxSize] 131 } 132 133 frameArray = append(frameArray, NewFragFrameBytesBuffer(FragmentTypeSplit, id, uint8(frames), uint8(i), frame)) 134 } 135 136 return frameArray 137 } 138 139 type ConnectionPacketConn struct { 140 conn quic.Connection 141 frag *Frag 142 } 143 144 func NewConnectionPacketConn(conn quic.Connection) *ConnectionPacketConn { 145 frag := &Frag{} 146 go frag.collect(conn.Context()) 147 return &ConnectionPacketConn{conn: conn, frag: frag} 148 } 149 150 func (c *ConnectionPacketConn) Context() context.Context { 151 return c.conn.Context() 152 } 153 154 func (c *ConnectionPacketConn) Receive(ctx context.Context) (uint64, *pool.Buffer, error) { 155 _retry: 156 data, err := c.conn.ReceiveDatagram(ctx) 157 if err != nil { 158 return 0, nil, err 159 } 160 161 buf := c.frag.Merge(data) 162 if buf == nil { 163 goto _retry 164 } 165 166 id := binary.BigEndian.Uint64(buf.Discard(8)) 167 168 return id, buf, nil 169 } 170 171 func (c *ConnectionPacketConn) Write(b []byte, id uint64) error { 172 buf := pool.GetBytesWriter(8 + len(b)) 173 defer buf.Free() 174 175 buf.WriteUint64(id) 176 _, _ = buf.Write(b) 177 178 buffers := c.frag.Split(buf.Bytes(), int(MaxDatagramFrameSize)) 179 defer buffers.Free() 180 181 for _, v := range buffers { 182 if err := c.conn.SendDatagram(v.Bytes()); err != nil { 183 return err 184 } 185 } 186 187 return nil 188 } 189 190 type FragType uint8 191 192 const ( 193 FragmentTypeSplit FragType = iota + 1 194 FragmentTypeSingle 195 ) 196 197 type fragFrame []byte 198 199 /* 200 every frame max length: 1200 - 3 201 202 Single Frame 203 max payload length: 1200 - 3 - 1 204 +-------+~~~~~~~~~~~~~~+ 205 | type | payload | 206 +-------+~~~~~~~~~~~~~~+ 207 | 1 | variable | 208 +-------+~~~~~~~~~~~~~~+ 209 210 Split Frame 211 max payload length: 1200 - 3 - 1 - 8 - 1 - 1 212 +------+------------------+---------+---------+~~~~~~~~~~~~~~+ 213 | type | id | total | current | payload | 214 +------+------------------+---------+---------+~~~~~~~~~~~~~~+ 215 | 1 | 8 bytes | 1 byte | 1 byte | variable | 216 +------+------------------+---------+---------+~~~~~~~~~~~~~~+ 217 */ 218 func NewFragFrameBytesBuffer(t FragType, id uint64, total, current uint8, payload []byte) *pool.Buffer { 219 var buf *pool.Buffer 220 if t == FragmentTypeSingle { 221 buf = pool.GetBytesWriter(1 + len(payload)) 222 } else { 223 buf = pool.GetBytesWriter(1 + 8 + 1 + 1 + len(payload)) 224 } 225 putFragFrame(buf, t, id, total, current, payload) 226 return buf 227 } 228 229 func putFragFrame(buf *pool.Buffer, t FragType, id uint64, total, current uint8, payload []byte) { 230 buf.WriteByte(byte(t)) 231 232 if t == FragmentTypeSingle { 233 buf.Write(payload) 234 return 235 } 236 237 buf.WriteUint64(id) 238 buf.WriteByte(total) 239 buf.WriteByte(current) 240 buf.Write(payload) 241 } 242 243 func (f fragFrame) Type() FragType { 244 if len(f) < 1 { 245 return 0 246 } 247 248 return FragType(f[0]) 249 } 250 251 func (f fragFrame) ID() uint64 { 252 if len(f) < 1+8 { 253 return 0 254 } 255 256 return binary.BigEndian.Uint64(f[1:]) 257 } 258 259 func (f fragFrame) Total() uint8 { 260 if len(f) < 1+8+1 { 261 return 0 262 } 263 264 return f[1+8+1-1] 265 } 266 267 func (f fragFrame) Current() uint8 { 268 if len(f) < 1+8+1+1 { 269 return 0 270 } 271 272 return f[1+8+1+1-1] 273 } 274 275 func (f fragFrame) Payload() []byte { 276 if f.Type() == FragmentTypeSingle { 277 return f[1:] 278 } 279 280 if len(f) < 1+8+1+1 { 281 return nil 282 } 283 284 return f[1+8+1+1:] 285 }