github.com/metacubex/mihomo@v1.18.5/transport/shadowsocks/shadowaead/stream.go (about) 1 package shadowaead 2 3 import ( 4 "crypto/cipher" 5 "crypto/rand" 6 "errors" 7 "io" 8 "net" 9 10 "github.com/metacubex/mihomo/common/pool" 11 ) 12 13 const ( 14 // payloadSizeMask is the maximum size of payload in bytes. 15 payloadSizeMask = 0x3FFF // 16*1024 - 1 16 bufSize = 17 * 1024 // >= 2+aead.Overhead()+payloadSizeMask+aead.Overhead() 17 ) 18 19 var ErrZeroChunk = errors.New("zero chunk") 20 21 type Writer struct { 22 io.Writer 23 cipher.AEAD 24 nonce [32]byte // should be sufficient for most nonce sizes 25 } 26 27 // NewWriter wraps an io.Writer with authenticated encryption. 28 func NewWriter(w io.Writer, aead cipher.AEAD) *Writer { return &Writer{Writer: w, AEAD: aead} } 29 30 // Write encrypts p and writes to the embedded io.Writer. 31 func (w *Writer) Write(p []byte) (n int, err error) { 32 buf := pool.Get(bufSize) 33 defer pool.Put(buf) 34 nonce := w.nonce[:w.NonceSize()] 35 tag := w.Overhead() 36 off := 2 + tag 37 38 // compatible with snell 39 if len(p) == 0 { 40 buf = buf[:off] 41 buf[0], buf[1] = byte(0), byte(0) 42 w.Seal(buf[:0], nonce, buf[:2], nil) 43 increment(nonce) 44 _, err = w.Writer.Write(buf) 45 return 46 } 47 48 for nr := 0; n < len(p) && err == nil; n += nr { 49 nr = payloadSizeMask 50 if n+nr > len(p) { 51 nr = len(p) - n 52 } 53 buf = buf[:off+nr+tag] 54 buf[0], buf[1] = byte(nr>>8), byte(nr) // big-endian payload size 55 w.Seal(buf[:0], nonce, buf[:2], nil) 56 increment(nonce) 57 w.Seal(buf[:off], nonce, p[n:n+nr], nil) 58 increment(nonce) 59 _, err = w.Writer.Write(buf) 60 } 61 return 62 } 63 64 // ReadFrom reads from the given io.Reader until EOF or error, encrypts and 65 // writes to the embedded io.Writer. Returns number of bytes read from r and 66 // any error encountered. 
// ReadFrom reads from the given io.Reader until EOF or error, encrypts and
// writes to the embedded io.Writer. Returns number of bytes read from r and
// any error encountered.
//
// Each non-empty r.Read result is sealed into one chunk (size header plus
// payload) and sent with a single underlying Write. When r reports EOF or an
// error, the bytes returned with it are flushed first. If there are none, an
// empty chunk is still written.
// NOTE(review): that empty chunk consists of the sealed zero size header plus
// the tag of an empty payload. Reader.read stops after a zero size header, so
// this package's own Reader leaves the trailing tag unread. Kept unchanged
// because peers may rely on it as an end marker. Confirm before altering.
func (w *Writer) ReadFrom(r io.Reader) (n int64, err error) {
	buf := pool.Get(bufSize)
	defer pool.Put(buf)
	nonce := w.nonce[:w.NonceSize()]
	tag := w.Overhead()
	off := 2 + tag
	for {
		nr, er := r.Read(buf[off : off+payloadSizeMask])
		if nr == 0 && er == nil {
			// A (0, nil) read carries nothing. Sealing it would put a
			// zero-length chunk mid-stream, which Reader.read rejects with
			// ErrZeroChunk, so just read again.
			continue
		}
		n += int64(nr)

		buf[0], buf[1] = byte(nr>>8), byte(nr) // big-endian payload size
		w.Seal(buf[:0], nonce, buf[:2], nil)
		increment(nonce)
		// Encrypt the payload in place; the ciphertext follows the sealed header.
		w.Seal(buf[:off], nonce, buf[off:off+nr], nil)
		increment(nonce)
		if _, ew := w.Writer.Write(buf[:off+nr+tag]); ew != nil {
			return n, ew
		}

		if er != nil {
			if er != io.EOF { // ignore EOF as per io.ReaderFrom contract
				err = er
			}
			return n, err
		}
	}
}

// Reader decrypts a stream of AEAD-sealed chunks read from the embedded
// io.Reader.
type Reader struct {
	io.Reader
	cipher.AEAD
	nonce [32]byte // should be sufficient for most nonce sizes
	buf   []byte   // pooled buffer holding decrypted data not yet consumed
	off   int      // offset to unconsumed part of buf
}

// NewReader wraps an io.Reader with authenticated decryption.
func NewReader(r io.Reader, aead cipher.AEAD) *Reader { return &Reader{Reader: r, AEAD: aead} }

// read reads one chunk from the embedded io.Reader and decrypts its payload
// into the front of p, returning the payload size.
//
// p is resliced up to size+Overhead(), so cap(p) must be at least
// payloadSizeMask+Overhead(); its length does not matter. A chunk whose size
// header decodes to zero yields ErrZeroChunk, and its payload section is not
// read. Each Open advances the nonce, whether or not it succeeds.
func (r *Reader) read(p []byte) (int, error) {
	nonce := r.nonce[:r.NonceSize()]
	tag := r.Overhead()

	// decrypt payload size
	p = p[:2+tag]
	if _, err := io.ReadFull(r.Reader, p); err != nil {
		return 0, err
	}
	_, err := r.Open(p[:0], nonce, p, nil)
	increment(nonce)
	if err != nil {
		return 0, err
	}

	// decrypt payload
	size := (int(p[0])<<8 + int(p[1])) & payloadSizeMask
	if size == 0 {
		return 0, ErrZeroChunk
	}

	p = p[:size+tag]
	if _, err := io.ReadFull(r.Reader, p); err != nil {
		return 0, err
	}
	_, err = r.Open(p[:0], nonce, p, nil)
	increment(nonce)
	if err != nil {
		return 0, err
	}
	return size, nil
}
// Read reads from the embedded io.Reader, decrypts and writes to p.
//
// When no decrypted data is pending and p is large enough for a maximum-size
// chunk, the chunk is decrypted directly into p. Otherwise one chunk is
// decrypted into a pooled buffer and handed out across successive Read calls.
// The buffer is returned to the pool once it is drained.
func (r *Reader) Read(p []byte) (int, error) {
	if r.buf == nil {
		if len(p) >= payloadSizeMask+r.Overhead() {
			return r.read(p)
		}
		b := pool.Get(bufSize)
		n, err := r.read(b)
		if err != nil {
			// Nothing was buffered; give the scratch buffer back to the pool.
			pool.Put(b)
			return 0, err
		}
		r.buf = b[:n]
		r.off = 0
	}

	n := copy(p, r.buf[r.off:])
	r.off += n
	if r.off == len(r.buf) {
		pool.Put(r.buf[:cap(r.buf)])
		r.buf = nil
	}
	return n, nil
}

// WriteTo reads from the embedded io.Reader, decrypts and writes to w until
// there's no more data to write or when an error occurs. Return number of
// bytes written to w and any error encountered.
//
// Any data left over from an earlier Read is flushed to w first. io.EOF from
// the underlying reader ends the copy with a nil error.
func (r *Reader) WriteTo(w io.Writer) (n int64, err error) {
	if r.buf == nil {
		r.buf = pool.Get(bufSize)
		r.off = len(r.buf)
	}

	for {
		for r.off < len(r.buf) {
			nw, ew := w.Write(r.buf[r.off:])
			r.off += nw
			n += int64(nw)
			if ew != nil {
				// Keep undelivered plaintext for a later Read; release only a
				// fully drained buffer.
				if r.off == len(r.buf) {
					pool.Put(r.buf[:cap(r.buf)])
					r.buf = nil
				}
				err = ew
				return
			}
		}

		nr, er := r.read(r.buf)
		if er != nil {
			// The inner loop drained buf, so nothing is pending. Release it
			// instead of leaving an empty buffer that the next Read would
			// report as a spurious (0, nil).
			pool.Put(r.buf[:cap(r.buf)])
			r.buf = nil
			if er != io.EOF {
				err = er
			}
			return
		}
		r.buf = r.buf[:nr]
		r.off = 0
	}
}

// increment little-endian encoded unsigned integer b. Wrap around on overflow.
func increment(b []byte) {
	for i := range b {
		b[i]++
		if b[i] != 0 {
			return
		}
	}
}

// Conn is a net.Conn whose traffic is encrypted with an AEAD derived from
// Cipher and a per-direction salt.
type Conn struct {
	net.Conn
	Cipher
	r *Reader // created on first Read/WriteTo, after the peer's salt is read
	w *Writer // created on first Write/ReadFrom, after the local salt is sent
}
217 func NewConn(c net.Conn, ciph Cipher) *Conn { return &Conn{Conn: c, Cipher: ciph} } 218 219 func (c *Conn) initReader() error { 220 salt := make([]byte, c.SaltSize()) 221 if _, err := io.ReadFull(c.Conn, salt); err != nil { 222 return err 223 } 224 225 aead, err := c.Decrypter(salt) 226 if err != nil { 227 return err 228 } 229 230 c.r = NewReader(c.Conn, aead) 231 return nil 232 } 233 234 func (c *Conn) Read(b []byte) (int, error) { 235 if c.r == nil { 236 if err := c.initReader(); err != nil { 237 return 0, err 238 } 239 } 240 return c.r.Read(b) 241 } 242 243 func (c *Conn) WriteTo(w io.Writer) (int64, error) { 244 if c.r == nil { 245 if err := c.initReader(); err != nil { 246 return 0, err 247 } 248 } 249 return c.r.WriteTo(w) 250 } 251 252 func (c *Conn) initWriter() error { 253 salt := make([]byte, c.SaltSize()) 254 if _, err := rand.Read(salt); err != nil { 255 return err 256 } 257 aead, err := c.Encrypter(salt) 258 if err != nil { 259 return err 260 } 261 _, err = c.Conn.Write(salt) 262 if err != nil { 263 return err 264 } 265 c.w = NewWriter(c.Conn, aead) 266 return nil 267 } 268 269 func (c *Conn) Write(b []byte) (int, error) { 270 if c.w == nil { 271 if err := c.initWriter(); err != nil { 272 return 0, err 273 } 274 } 275 return c.w.Write(b) 276 } 277 278 func (c *Conn) ReadFrom(r io.Reader) (int64, error) { 279 if c.w == nil { 280 if err := c.initWriter(); err != nil { 281 return 0, err 282 } 283 } 284 return c.w.ReadFrom(r) 285 }