github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/yuubinsya/crypto/aead.go (about) 1 package crypto 2 3 import ( 4 "crypto/cipher" 5 "encoding/binary" 6 "io" 7 "net" 8 "sync" 9 10 "github.com/Asutorufa/yuhaiin/pkg/net/nat" 11 "github.com/Asutorufa/yuhaiin/pkg/utils/pool" 12 "golang.org/x/crypto/chacha20poly1305" 13 ) 14 15 var Chacha20poly1305 = chacha20poly1305Aead{} 16 17 type chacha20poly1305Aead struct{} 18 19 func (chacha20poly1305Aead) New(key []byte) (cipher.AEAD, error) { return chacha20poly1305.New(key) } 20 func (chacha20poly1305Aead) KeySize() int { return chacha20poly1305.KeySize } 21 func (chacha20poly1305Aead) NonceSize() int { return chacha20poly1305.NonceSize } 22 func (chacha20poly1305Aead) Name() []byte { return []byte("chacha20poly1305-key") } 23 24 type streamConn struct { 25 net.Conn 26 r io.Reader 27 w io.Writer 28 } 29 30 func (c *streamConn) Read(b []byte) (int, error) { return c.r.Read(b) } 31 func (c *streamConn) Write(b []byte) (int, error) { return c.w.Write(b) } 32 33 // NewConn wraps a stream-oriented net.Conn with cipher. 34 func NewConn(c net.Conn, rnonce, wnonce []byte, rciph, wciph cipher.AEAD) net.Conn { 35 return &streamConn{ 36 Conn: c, 37 r: NewReader(c, rnonce, rciph, nat.MaxSegmentSize), 38 w: NewWriter(c, wnonce, wciph, nat.MaxSegmentSize), 39 } 40 } 41 42 type writer struct { 43 io.Writer 44 cipher.AEAD 45 nonce []byte 46 maxPayloadSize int 47 48 mu sync.Mutex 49 } 50 51 // NewWriter wraps an io.Writer with AEAD encryption. 52 53 func NewWriter(w io.Writer, nonce []byte, aead cipher.AEAD, maxPayloadSize int) *writer { 54 return &writer{ 55 Writer: w, 56 AEAD: aead, 57 nonce: nonce, 58 maxPayloadSize: maxPayloadSize, 59 } 60 } 61 62 func (w *writer) Write(p []byte) (n int, err error) { 63 if len(p) == 0 { 64 return 65 } 66 67 buf := pool.GetBytes(2 + w.AEAD.Overhead() + w.maxPayloadSize + w.AEAD.Overhead()) 68 defer pool.PutBytes(buf) 69 70 for pLen := len(p); pLen > 0; { 71 var data []byte 72 if pLen > w.maxPayloadSize { 73 data = p[:w.maxPayloadSize] 74 p = p[w.maxPayloadSize:] 75 pLen -= w.maxPayloadSize 76 } else { 77 data = p 78 pLen = 0 79 } 80 binary.BigEndian.PutUint16(buf[:2], uint16(len(data))) 81 w.mu.Lock() 82 w.Seal(buf[:0], w.nonce, buf[:2], nil) 83 increment(w.nonce) 84 offset := w.Overhead() + 2 85 packet := w.Seal(buf[offset:offset], w.nonce, data, nil) 86 increment(w.nonce) 87 _, err = w.Writer.Write(buf[:offset+len(packet)]) 88 w.mu.Unlock() 89 if err != nil { 90 return 91 } 92 n += len(data) 93 } 94 95 return 96 } 97 98 type reader struct { 99 io.Reader 100 cipher.AEAD 101 nonce []byte 102 buf []byte 103 leftover []byte 104 105 mu sync.Mutex 106 } 107 108 func NewReader(r io.Reader, nonce []byte, aead cipher.AEAD, maxPayloadSize int) *reader { 109 return &reader{ 110 Reader: r, 111 AEAD: aead, 112 buf: make([]byte, maxPayloadSize+aead.Overhead()), 113 nonce: nonce, 114 } 115 } 116 117 // read and decrypt a record into the internal buffer. Return decrypted payload length and any error encountered. 118 func (r *reader) read() (int, error) { 119 // decrypt payload size 120 buf := r.buf[:2+r.Overhead()] 121 _, err := io.ReadFull(r.Reader, buf) 122 if err != nil { 123 return 0, err 124 } 125 126 _, err = r.Open(buf[:0], r.nonce, buf, nil) 127 increment(r.nonce) 128 if err != nil { 129 return 0, err 130 } 131 132 size := int(binary.BigEndian.Uint16(buf[:2])) 133 134 // decrypt payload 135 buf = r.buf[:size+r.Overhead()] 136 _, err = io.ReadFull(r.Reader, buf) 137 if err != nil { 138 return 0, err 139 } 140 141 _, err = r.Open(buf[:0], r.nonce, buf, nil) 142 increment(r.nonce) 143 if err != nil { 144 return 0, err 145 } 146 147 return size, nil 148 } 149 150 // Read reads from the embedded io.Reader, decrypts and writes to b. 151 func (r *reader) Read(b []byte) (int, error) { 152 r.mu.Lock() 153 defer r.mu.Unlock() 154 155 // copy decrypted bytes (if any) from previous record first 156 if len(r.leftover) > 0 { 157 n := copy(b, r.leftover) 158 r.leftover = r.leftover[n:] 159 return n, nil 160 } 161 162 n, err := r.read() 163 164 m := copy(b, r.buf[:n]) 165 if m < n { // insufficient len(b), keep leftover for next read 166 r.leftover = r.buf[m:n] 167 } 168 return m, err 169 } 170 171 // increment little-endian encoded unsigned integer b. Wrap around on overflow. 172 func increment(b []byte) { 173 for i := range b { 174 b[i]++ 175 if b[i] != 0 { 176 return 177 } 178 } 179 }