github.com/yaling888/clash@v1.53.0/transport/shadowsocks/shadowaead/stream.go (about) 1 package shadowaead 2 3 import ( 4 "crypto/cipher" 5 "crypto/rand" 6 "errors" 7 "io" 8 "net" 9 "sync" 10 ) 11 12 const ( 13 // payloadSizeMask is the maximum size of payload in bytes. 14 payloadSizeMask = 0x3FFF // 16*1024 - 1 15 bufSize = 17 * 1024 // >= 2+aead.Overhead()+payloadSizeMask+aead.Overhead() 16 ) 17 18 var ErrZeroChunk = errors.New("zero chunk") 19 20 type Writer struct { 21 io.Writer 22 cipher.AEAD 23 nonce [32]byte // should be sufficient for most nonce sizes 24 } 25 26 // NewWriter wraps an io.Writer with authenticated encryption. 27 func NewWriter(w io.Writer, aead cipher.AEAD) *Writer { return &Writer{Writer: w, AEAD: aead} } 28 29 var bufPool = sync.Pool{ 30 New: func() any { 31 b := make([]byte, bufSize) 32 return &b 33 }, 34 } 35 36 // Write encrypts p and writes to the embedded io.Writer. 37 func (w *Writer) Write(p []byte) (n int, err error) { 38 bufP := bufPool.Get().(*[]byte) 39 defer bufPool.Put(bufP) 40 nonce := w.nonce[:w.NonceSize()] 41 tag := w.Overhead() 42 off := 2 + tag 43 44 // compatible with snell 45 if len(p) == 0 { 46 (*bufP)[0], (*bufP)[1] = byte(0), byte(0) 47 w.Seal((*bufP)[:0], nonce, (*bufP)[:2], nil) 48 increment(nonce) 49 _, err = w.Writer.Write((*bufP)[:off]) 50 return 51 } 52 53 for nr := 0; n < len(p) && err == nil; n += nr { 54 nr = payloadSizeMask 55 if n+nr > len(p) { 56 nr = len(p) - n 57 } 58 (*bufP)[0], (*bufP)[1] = byte(nr>>8), byte(nr) // big-endian payload size 59 w.Seal((*bufP)[:0], nonce, (*bufP)[:2], nil) 60 increment(nonce) 61 w.Seal((*bufP)[:off], nonce, p[n:n+nr], nil) 62 increment(nonce) 63 _, err = w.Writer.Write((*bufP)[:off+nr+tag]) 64 } 65 return 66 } 67 68 // ReadFrom reads from the given io.Reader until EOF or error, encrypts and 69 // writes to the embedded io.Writer. Returns number of bytes read from r and 70 // any error encountered. 71 func (w *Writer) ReadFrom(r io.Reader) (n int64, err error) { 72 bufP := bufPool.Get().(*[]byte) 73 defer bufPool.Put(bufP) 74 nonce := w.nonce[:w.NonceSize()] 75 tag := w.Overhead() 76 off := 2 + tag 77 for { 78 nr, er := r.Read((*bufP)[off : off+payloadSizeMask]) 79 n += int64(nr) 80 (*bufP)[0], (*bufP)[1] = byte(nr>>8), byte(nr) 81 w.Seal((*bufP)[:0], nonce, (*bufP)[:2], nil) 82 increment(nonce) 83 w.Seal((*bufP)[:off], nonce, (*bufP)[off:off+nr], nil) 84 increment(nonce) 85 if _, ew := w.Writer.Write((*bufP)[:off+nr+tag]); ew != nil { 86 err = ew 87 return 88 } 89 if er != nil { 90 if er != io.EOF { // ignore EOF as per io.ReaderFrom contract 91 err = er 92 } 93 return 94 } 95 } 96 } 97 98 type Reader struct { 99 io.Reader 100 cipher.AEAD 101 nonce [32]byte // should be sufficient for most nonce sizes 102 bufP *[]byte // to be put back into bufPool 103 off int // offset to unconsumed part of buf 104 } 105 106 // NewReader wraps an io.Reader with authenticated decryption. 107 func NewReader(r io.Reader, aead cipher.AEAD) *Reader { return &Reader{Reader: r, AEAD: aead} } 108 109 // Read and decrypt a record into p. len(p) >= max payload size + AEAD overhead. 110 func (r *Reader) read(p []byte) (int, error) { 111 nonce := r.nonce[:r.NonceSize()] 112 tag := r.Overhead() 113 114 // decrypt payload size 115 p = p[:2+tag] 116 if _, err := io.ReadFull(r.Reader, p); err != nil { 117 return 0, err 118 } 119 _, err := r.Open(p[:0], nonce, p, nil) 120 increment(nonce) 121 if err != nil { 122 return 0, err 123 } 124 125 // decrypt payload 126 size := (int(p[0])<<8 + int(p[1])) & payloadSizeMask 127 if size == 0 { 128 return 0, ErrZeroChunk 129 } 130 131 p = p[:size+tag] 132 if _, err := io.ReadFull(r.Reader, p); err != nil { 133 return 0, err 134 } 135 _, err = r.Open(p[:0], nonce, p, nil) 136 increment(nonce) 137 if err != nil { 138 return 0, err 139 } 140 return size, nil 141 } 142 143 // Read reads from the embedded io.Reader, decrypts and writes to p. 144 func (r *Reader) Read(p []byte) (int, error) { 145 if r.bufP == nil { 146 if len(p) >= payloadSizeMask+r.Overhead() { 147 return r.read(p) 148 } 149 bp := bufPool.Get().(*[]byte) 150 n, err := r.read(*bp) 151 if err != nil { 152 return 0, err 153 } 154 *bp = (*bp)[:n] 155 r.bufP = bp 156 r.off = 0 157 } 158 159 n := copy(p, (*r.bufP)[r.off:]) 160 r.off += n 161 if r.off == len(*r.bufP) { 162 *r.bufP = (*r.bufP)[:bufSize] 163 bufPool.Put(r.bufP) 164 r.bufP = nil 165 } 166 return n, nil 167 } 168 169 // WriteTo reads from the embedded io.Reader, decrypts and writes to w until 170 // there's no more data to write or when an error occurs. Return number of 171 // bytes written to w and any error encountered. 172 func (r *Reader) WriteTo(w io.Writer) (n int64, err error) { 173 if r.bufP == nil { 174 r.bufP = bufPool.Get().(*[]byte) 175 r.off = len(*r.bufP) 176 } 177 178 for { 179 for r.off < len(*r.bufP) { 180 nw, ew := w.Write((*r.bufP)[r.off:]) 181 r.off += nw 182 n += int64(nw) 183 if ew != nil { 184 if r.off == len(*r.bufP) { 185 *r.bufP = (*r.bufP)[:bufSize] 186 bufPool.Put(r.bufP) 187 r.bufP = nil 188 } 189 err = ew 190 return 191 } 192 } 193 194 nr, er := r.read(*r.bufP) 195 if er != nil { 196 if er != io.EOF { 197 err = er 198 } 199 return 200 } 201 *r.bufP = (*r.bufP)[:nr] 202 r.off = 0 203 } 204 } 205 206 // increment little-endian encoded unsigned integer b. Wrap around on overflow. 207 func increment(b []byte) { 208 for i := range b { 209 b[i]++ 210 if b[i] != 0 { 211 return 212 } 213 } 214 } 215 216 type Conn struct { 217 net.Conn 218 Cipher 219 r *Reader 220 w *Writer 221 } 222 223 // NewConn wraps a stream-oriented net.Conn with cipher. 224 func NewConn(c net.Conn, ciph Cipher) *Conn { return &Conn{Conn: c, Cipher: ciph} } 225 226 func (c *Conn) initReader() error { 227 salt := make([]byte, c.SaltSize()) 228 if _, err := io.ReadFull(c.Conn, salt); err != nil { 229 return err 230 } 231 232 aead, err := c.Decrypter(salt) 233 if err != nil { 234 return err 235 } 236 237 c.r = NewReader(c.Conn, aead) 238 return nil 239 } 240 241 func (c *Conn) Read(b []byte) (int, error) { 242 if c.r == nil { 243 if err := c.initReader(); err != nil { 244 return 0, err 245 } 246 } 247 return c.r.Read(b) 248 } 249 250 func (c *Conn) WriteTo(w io.Writer) (int64, error) { 251 if c.r == nil { 252 if err := c.initReader(); err != nil { 253 return 0, err 254 } 255 } 256 return c.r.WriteTo(w) 257 } 258 259 func (c *Conn) initWriter() error { 260 salt := make([]byte, c.SaltSize()) 261 if _, err := rand.Read(salt); err != nil { 262 return err 263 } 264 aead, err := c.Encrypter(salt) 265 if err != nil { 266 return err 267 } 268 _, err = c.Conn.Write(salt) 269 if err != nil { 270 return err 271 } 272 c.w = NewWriter(c.Conn, aead) 273 return nil 274 } 275 276 func (c *Conn) Write(b []byte) (int, error) { 277 if c.w == nil { 278 if err := c.initWriter(); err != nil { 279 return 0, err 280 } 281 } 282 return c.w.Write(b) 283 } 284 285 func (c *Conn) ReadFrom(r io.Reader) (int64, error) { 286 if c.w == nil { 287 if err := c.initWriter(); err != nil { 288 return 0, err 289 } 290 } 291 return c.w.ReadFrom(r) 292 }