github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/vmess/chunk.go (about) 1 package vmess 2 3 import ( 4 "bytes" 5 "encoding/binary" 6 "io" 7 "net" 8 9 "github.com/Asutorufa/yuhaiin/pkg/utils/pool" 10 "github.com/Asutorufa/yuhaiin/pkg/utils/relay" 11 ) 12 13 type writer interface { 14 io.WriteCloser 15 io.ReaderFrom 16 } 17 18 type connWriter struct { 19 net.Conn 20 } 21 22 func (c *connWriter) ReadFrom(r io.Reader) (int64, error) { 23 return relay.Copy(c.Conn, r) 24 } 25 26 const ( 27 lenSize = 2 28 maxChunkSize = 1 << 14 // 16384 29 defaultChunkSize = 1 << 13 // 8192 30 ) 31 32 var _ writer = &aeadWriter{} 33 34 type chunkedWriter struct { 35 io.Writer 36 buf [lenSize + maxChunkSize]byte 37 } 38 39 // ChunkedWriter returns a chunked writer 40 func ChunkedWriter(w io.Writer) writer { return &chunkedWriter{Writer: w} } 41 42 func (w *chunkedWriter) Close() error { return nil } 43 44 func (w *chunkedWriter) Write(b []byte) (int, error) { 45 n, err := w.ReadFrom(bytes.NewBuffer(b)) 46 return int(n), err 47 } 48 49 func (w *chunkedWriter) ReadFrom(r io.Reader) (n int64, err error) { 50 for { 51 nr, er := r.Read(w.buf[lenSize : lenSize+defaultChunkSize]) 52 if nr > 0 { 53 n += int64(nr) 54 binary.BigEndian.PutUint16(w.buf[:lenSize], uint16(nr)) 55 _, err = w.Writer.Write(w.buf[:lenSize+nr]) 56 if err != nil { 57 // err = ew 58 break 59 } 60 } 61 62 if er != nil { 63 if er != io.EOF { // ignore EOF as per io.ReaderFrom contract 64 err = er 65 } 66 break 67 } 68 } 69 70 return n, err 71 } 72 73 type chunkedReader struct { 74 io.Reader 75 leftBytes int 76 } 77 78 // ChunkedReader returns a chunked reader 79 func ChunkedReader(r io.Reader) io.ReadCloser { return &chunkedReader{Reader: r} } 80 func (r *chunkedReader) Close() error { return nil } 81 func (r *chunkedReader) Read(b []byte) (int, error) { 82 if r.leftBytes <= 0 { 83 buf := pool.GetBytes(lenSize) 84 defer pool.PutBytes(buf) 85 86 // get length 87 _, err := io.ReadFull(r.Reader, buf[:lenSize]) 88 if err != nil { 89 return 0, err 90 } 91 r.leftBytes = int(binary.BigEndian.Uint16(buf[:lenSize])) 92 93 // if length == 0, then this is the end 94 if r.leftBytes <= 0 { 95 return 0, nil 96 } 97 } 98 99 m, err := r.Reader.Read(b) 100 if err != nil { 101 return 0, err 102 } 103 r.leftBytes -= m 104 105 return m, nil 106 }