github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/vmess/aead.go (about) 1 package vmess 2 3 import ( 4 "bytes" 5 "crypto/cipher" 6 "encoding/binary" 7 "io" 8 9 "github.com/Asutorufa/yuhaiin/pkg/utils/pool" 10 ) 11 12 var _ io.WriteCloser = &aeadWriter{} 13 14 type aeadWriter struct { 15 io.Writer 16 cipher.AEAD 17 nonce []byte 18 buf [lenSize + maxChunkSize]byte 19 count uint16 20 iv []byte 21 } 22 23 // AEADWriter returns a aead writer 24 func AEADWriter(w io.Writer, aead cipher.AEAD, iv []byte) writer { 25 return &aeadWriter{ 26 Writer: w, 27 AEAD: aead, 28 nonce: make([]byte, aead.NonceSize()), 29 count: 0, 30 iv: iv, 31 } 32 } 33 34 func (w *aeadWriter) Close() error { return nil } 35 36 func (w *aeadWriter) Write(b []byte) (int, error) { 37 n, err := w.ReadFrom(bytes.NewBuffer(b)) 38 return int(n), err 39 } 40 41 func (w *aeadWriter) ReadFrom(r io.Reader) (n int64, err error) { 42 buf := w.buf[:] 43 for { 44 payloadBuf := w.buf[lenSize : lenSize+defaultChunkSize-w.Overhead()] 45 46 nr, er := r.Read(payloadBuf) 47 if nr > 0 { 48 n += int64(nr) 49 buf = buf[:lenSize+nr+w.Overhead()] 50 payloadBuf = payloadBuf[:nr] 51 binary.BigEndian.PutUint16(w.buf[:lenSize], uint16(nr+w.Overhead())) 52 53 binary.BigEndian.PutUint16(w.nonce[:2], w.count) 54 copy(w.nonce[2:], w.iv[2:12]) 55 56 w.Seal(payloadBuf[:0], w.nonce[:w.NonceSize()], payloadBuf, nil) 57 w.count++ 58 59 _, ew := w.Writer.Write(buf) 60 if ew != nil { 61 err = ew 62 break 63 } 64 } 65 66 if er != nil { 67 if er != io.EOF { // ignore EOF as per io.ReaderFrom contract 68 err = er 69 } 70 break 71 } 72 } 73 74 return n, err 75 } 76 77 var _ io.ReadCloser = &aeadReader{} 78 79 type aeadReader struct { 80 io.Reader 81 cipher.AEAD 82 count uint16 83 iv []byte 84 85 decrypted bytes.Buffer 86 } 87 88 // AEADReader returns a aead reader 89 func AEADReader(r io.Reader, aead cipher.AEAD, iv []byte) io.ReadCloser { 90 return &aeadReader{ 91 Reader: r, 92 AEAD: aead, 93 count: 0, 94 iv: iv, 95 } 96 } 97 98 func (r *aeadReader) Close() error { return nil } 99 100 func (r *aeadReader) Read(b []byte) (int, error) { 101 if r.decrypted.Len() > 0 { 102 return r.decrypted.Read(b) 103 } 104 105 lb := pool.GetBytes(r.NonceSize()) 106 defer pool.PutBytes(lb) 107 108 // get length 109 _, err := io.ReadFull(r.Reader, lb[:lenSize]) 110 if err != nil { 111 return 0, err 112 } 113 114 // if length == 0, then this is the end 115 l := binary.BigEndian.Uint16(lb[:lenSize]) 116 if l == 0 { 117 return 0, nil 118 } 119 120 buf := pool.GetBytes(int(l)) 121 defer pool.PutBytes(buf) 122 // get payload 123 _, err = io.ReadFull(r.Reader, buf[:l]) 124 if err != nil { 125 return 0, err 126 } 127 128 binary.BigEndian.PutUint16(lb[:2], r.count) 129 copy(lb[2:], r.iv[2:12]) 130 131 _, err = r.Open(buf[:0], lb[:r.NonceSize()], buf[:l], nil) 132 r.count++ 133 if err != nil { 134 return 0, err 135 } 136 137 r.decrypted.Write(buf[:int(l)-r.Overhead()]) 138 return r.decrypted.Read(b) 139 }