github.com/yaling888/clash@v1.53.0/transport/vmess/aead.go (about) 1 package vmess 2 3 import ( 4 "crypto/cipher" 5 "encoding/binary" 6 "errors" 7 "io" 8 "sync" 9 10 "github.com/yaling888/clash/common/pool" 11 ) 12 13 type aeadWriter struct { 14 io.Writer 15 cipher.AEAD 16 nonce [32]byte 17 count uint16 18 iv []byte 19 20 writeLock sync.Mutex 21 } 22 23 func newAEADWriter(w io.Writer, aead cipher.AEAD, iv []byte) *aeadWriter { 24 return &aeadWriter{Writer: w, AEAD: aead, iv: iv} 25 } 26 27 func (w *aeadWriter) Write(b []byte) (n int, err error) { 28 w.writeLock.Lock() 29 bufP := pool.GetNetBuf() 30 defer func() { 31 w.writeLock.Unlock() 32 pool.PutNetBuf(bufP) 33 }() 34 length := len(b) 35 for { 36 if length == 0 { 37 break 38 } 39 readLen := chunkSize - w.Overhead() 40 if length < readLen { 41 readLen = length 42 } 43 payloadBuf := (*bufP)[lenSize : lenSize+chunkSize-w.Overhead()] 44 copy(payloadBuf, b[n:n+readLen]) 45 46 binary.BigEndian.PutUint16((*bufP)[:lenSize], uint16(readLen+w.Overhead())) 47 binary.BigEndian.PutUint16(w.nonce[:2], w.count) 48 copy(w.nonce[2:], w.iv[2:12]) 49 50 w.Seal(payloadBuf[:0], w.nonce[:w.NonceSize()], payloadBuf[:readLen], nil) 51 w.count++ 52 53 _, err = w.Writer.Write((*bufP)[:lenSize+readLen+w.Overhead()]) 54 if err != nil { 55 break 56 } 57 n += readLen 58 length -= readLen 59 } 60 return 61 } 62 63 type aeadReader struct { 64 io.Reader 65 cipher.AEAD 66 nonce [32]byte 67 bufP *[]byte 68 offset int 69 iv []byte 70 sizeBuf []byte 71 count uint16 72 } 73 74 func newAEADReader(r io.Reader, aead cipher.AEAD, iv []byte) *aeadReader { 75 return &aeadReader{Reader: r, AEAD: aead, iv: iv, sizeBuf: make([]byte, lenSize)} 76 } 77 78 func (r *aeadReader) Read(b []byte) (int, error) { 79 if r.bufP != nil { 80 n := copy(b, (*r.bufP)[r.offset:]) 81 r.offset += n 82 if r.offset == len(*r.bufP) { 83 pool.PutNetBuf(r.bufP) 84 r.bufP = nil 85 } 86 return n, nil 87 } 88 89 _, err := io.ReadFull(r.Reader, r.sizeBuf) 90 if err != nil { 91 return 0, err 92 } 93 94 size := int(binary.BigEndian.Uint16(r.sizeBuf)) 95 if size > maxSize { 96 return 0, errors.New("buffer is larger than standard") 97 } 98 99 bufP := pool.GetNetBuf() 100 _, err = io.ReadFull(r.Reader, (*bufP)[:size]) 101 if err != nil { 102 pool.PutNetBuf(bufP) 103 return 0, err 104 } 105 106 binary.BigEndian.PutUint16(r.nonce[:2], r.count) 107 copy(r.nonce[2:], r.iv[2:12]) 108 109 _, err = r.Open((*bufP)[:0], r.nonce[:r.NonceSize()], (*bufP)[:size], nil) 110 r.count++ 111 if err != nil { 112 pool.PutNetBuf(bufP) 113 return 0, err 114 } 115 realLen := size - r.Overhead() 116 n := copy(b, (*bufP)[:realLen]) 117 if len(b) >= realLen { 118 pool.PutNetBuf(bufP) 119 return n, nil 120 } 121 122 *bufP = (*bufP)[:realLen] 123 r.offset = n 124 r.bufP = bufP 125 return n, nil 126 }