github.com/chwjbn/xclash@v0.2.0/transport/vmess/aead.go (about) 1 package vmess 2 3 import ( 4 "crypto/cipher" 5 "encoding/binary" 6 "errors" 7 "io" 8 "sync" 9 10 "github.com/chwjbn/xclash/common/pool" 11 ) 12 13 type aeadWriter struct { 14 io.Writer 15 cipher.AEAD 16 nonce [32]byte 17 count uint16 18 iv []byte 19 20 writeLock sync.Mutex 21 } 22 23 func newAEADWriter(w io.Writer, aead cipher.AEAD, iv []byte) *aeadWriter { 24 return &aeadWriter{Writer: w, AEAD: aead, iv: iv} 25 } 26 27 func (w *aeadWriter) Write(b []byte) (n int, err error) { 28 w.writeLock.Lock() 29 buf := pool.Get(pool.RelayBufferSize) 30 defer func() { 31 w.writeLock.Unlock() 32 pool.Put(buf) 33 }() 34 length := len(b) 35 for { 36 if length == 0 { 37 break 38 } 39 readLen := chunkSize - w.Overhead() 40 if length < readLen { 41 readLen = length 42 } 43 payloadBuf := buf[lenSize : lenSize+chunkSize-w.Overhead()] 44 copy(payloadBuf, b[n:n+readLen]) 45 46 binary.BigEndian.PutUint16(buf[:lenSize], uint16(readLen+w.Overhead())) 47 binary.BigEndian.PutUint16(w.nonce[:2], w.count) 48 copy(w.nonce[2:], w.iv[2:12]) 49 50 w.Seal(payloadBuf[:0], w.nonce[:w.NonceSize()], payloadBuf[:readLen], nil) 51 w.count++ 52 53 _, err = w.Writer.Write(buf[:lenSize+readLen+w.Overhead()]) 54 if err != nil { 55 break 56 } 57 n += readLen 58 length -= readLen 59 } 60 return 61 } 62 63 type aeadReader struct { 64 io.Reader 65 cipher.AEAD 66 nonce [32]byte 67 buf []byte 68 offset int 69 iv []byte 70 sizeBuf []byte 71 count uint16 72 } 73 74 func newAEADReader(r io.Reader, aead cipher.AEAD, iv []byte) *aeadReader { 75 return &aeadReader{Reader: r, AEAD: aead, iv: iv, sizeBuf: make([]byte, lenSize)} 76 } 77 78 func (r *aeadReader) Read(b []byte) (int, error) { 79 if r.buf != nil { 80 n := copy(b, r.buf[r.offset:]) 81 r.offset += n 82 if r.offset == len(r.buf) { 83 pool.Put(r.buf) 84 r.buf = nil 85 } 86 return n, nil 87 } 88 89 _, err := io.ReadFull(r.Reader, r.sizeBuf) 90 if err != nil { 91 return 0, err 92 } 93 94 size := int(binary.BigEndian.Uint16(r.sizeBuf)) 95 if size > maxSize { 96 return 0, errors.New("buffer is larger than standard") 97 } 98 99 buf := pool.Get(size) 100 _, err = io.ReadFull(r.Reader, buf[:size]) 101 if err != nil { 102 pool.Put(buf) 103 return 0, err 104 } 105 106 binary.BigEndian.PutUint16(r.nonce[:2], r.count) 107 copy(r.nonce[2:], r.iv[2:12]) 108 109 _, err = r.Open(buf[:0], r.nonce[:r.NonceSize()], buf[:size], nil) 110 r.count++ 111 if err != nil { 112 return 0, err 113 } 114 realLen := size - r.Overhead() 115 n := copy(b, buf[:realLen]) 116 if len(b) >= realLen { 117 pool.Put(buf) 118 return n, nil 119 } 120 121 r.offset = n 122 r.buf = buf[:realLen] 123 return n, nil 124 }