github.com/sagernet/sing-shadowsocks@v0.2.6/shadowaead/aead.go (about) 1 package shadowaead 2 3 import ( 4 "crypto/cipher" 5 "encoding/binary" 6 "io" 7 "sync" 8 9 "github.com/sagernet/sing/common/buf" 10 ) 11 12 // https://shadowsocks.org/en/wiki/AEAD-Ciphers.html 13 const ( 14 MaxPacketSize = 16*1024 - 1 15 PacketLengthBufferSize = 2 16 ) 17 18 const ( 19 // Overhead 20 // crypto/cipher.gcmTagSize 21 // golang.org/x/crypto/chacha20poly1305.Overhead 22 Overhead = 16 23 ) 24 25 type Reader struct { 26 upstream io.Reader 27 cipher cipher.AEAD 28 buffer []byte 29 nonce []byte 30 index int 31 cached int 32 } 33 34 func NewReader(upstream io.Reader, cipher cipher.AEAD, maxPacketSize int) *Reader { 35 return &Reader{ 36 upstream: upstream, 37 cipher: cipher, 38 buffer: make([]byte, maxPacketSize+Overhead), 39 nonce: make([]byte, cipher.NonceSize()), 40 } 41 } 42 43 func NewRawReader(upstream io.Reader, cipher cipher.AEAD, buffer []byte, nonce []byte) *Reader { 44 return &Reader{ 45 upstream: upstream, 46 cipher: cipher, 47 buffer: buffer, 48 nonce: nonce, 49 } 50 } 51 52 func (r *Reader) Upstream() any { 53 return r.upstream 54 } 55 56 func (r *Reader) WriteTo(writer io.Writer) (n int64, err error) { 57 if r.cached > 0 { 58 writeN, writeErr := writer.Write(r.buffer[r.index : r.index+r.cached]) 59 if writeErr != nil { 60 return int64(writeN), writeErr 61 } 62 n += int64(writeN) 63 } 64 for { 65 start := PacketLengthBufferSize + Overhead 66 _, err = io.ReadFull(r.upstream, r.buffer[:start]) 67 if err != nil { 68 return 69 } 70 _, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:start], nil) 71 if err != nil { 72 return 73 } 74 increaseNonce(r.nonce) 75 length := int(binary.BigEndian.Uint16(r.buffer[:PacketLengthBufferSize])) 76 end := length + Overhead 77 _, err = io.ReadFull(r.upstream, r.buffer[:end]) 78 if err != nil { 79 return 80 } 81 _, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:end], nil) 82 if err != nil { 83 return 84 } 85 increaseNonce(r.nonce) 86 writeN, writeErr := writer.Write(r.buffer[:length]) 87 if writeErr != nil { 88 return int64(writeN), writeErr 89 } 90 n += int64(writeN) 91 } 92 } 93 94 func (r *Reader) readInternal() (err error) { 95 start := PacketLengthBufferSize + Overhead 96 _, err = io.ReadFull(r.upstream, r.buffer[:start]) 97 if err != nil { 98 return err 99 } 100 _, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:start], nil) 101 if err != nil { 102 return err 103 } 104 increaseNonce(r.nonce) 105 length := int(binary.BigEndian.Uint16(r.buffer[:PacketLengthBufferSize])) 106 end := length + Overhead 107 _, err = io.ReadFull(r.upstream, r.buffer[:end]) 108 if err != nil { 109 return err 110 } 111 _, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:end], nil) 112 if err != nil { 113 return err 114 } 115 increaseNonce(r.nonce) 116 r.cached = length 117 r.index = 0 118 return nil 119 } 120 121 func (r *Reader) ReadByte() (byte, error) { 122 if r.cached == 0 { 123 err := r.readInternal() 124 if err != nil { 125 return 0, err 126 } 127 } 128 index := r.index 129 r.index++ 130 r.cached-- 131 return r.buffer[index], nil 132 } 133 134 func (r *Reader) Read(b []byte) (n int, err error) { 135 if r.cached > 0 { 136 n = copy(b, r.buffer[r.index:r.index+r.cached]) 137 r.cached -= n 138 r.index += n 139 return 140 } 141 start := PacketLengthBufferSize + Overhead 142 _, err = io.ReadFull(r.upstream, r.buffer[:start]) 143 if err != nil { 144 return 0, err 145 } 146 _, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:start], nil) 147 if err != nil { 148 return 0, err 149 } 150 increaseNonce(r.nonce) 151 length := int(binary.BigEndian.Uint16(r.buffer[:PacketLengthBufferSize])) 152 end := length + Overhead 153 154 if len(b) >= end { 155 data := b[:end] 156 _, err = io.ReadFull(r.upstream, data) 157 if err != nil { 158 return 0, err 159 } 160 _, err = r.cipher.Open(b[:0], r.nonce, data, nil) 161 if err != nil { 162 return 0, err 163 } 164 increaseNonce(r.nonce) 165 return length, nil 166 } else { 167 _, err = io.ReadFull(r.upstream, r.buffer[:end]) 168 if err != nil { 169 return 0, err 170 } 171 _, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:end], nil) 172 if err != nil { 173 return 0, err 174 } 175 increaseNonce(r.nonce) 176 n = copy(b, r.buffer[:length]) 177 r.cached = length - n 178 r.index = n 179 return 180 } 181 } 182 183 func (r *Reader) Discard(n int) error { 184 for { 185 if r.cached >= n { 186 r.cached -= n 187 r.index += n 188 return nil 189 } else if r.cached > 0 { 190 n -= r.cached 191 r.cached = 0 192 r.index = 0 193 } 194 err := r.readInternal() 195 if err != nil { 196 return err 197 } 198 } 199 } 200 201 func (r *Reader) Buffer() *buf.Buffer { 202 buffer := buf.With(r.buffer) 203 buffer.Resize(r.index, r.cached) 204 return buffer 205 } 206 207 func (r *Reader) Cached() int { 208 return r.cached 209 } 210 211 func (r *Reader) CachedSlice() []byte { 212 return r.buffer[r.index : r.index+r.cached] 213 } 214 215 func (r *Reader) ReadWithLengthChunk(lengthChunk []byte) error { 216 _, err := r.cipher.Open(r.buffer[:0], r.nonce, lengthChunk, nil) 217 if err != nil { 218 return err 219 } 220 increaseNonce(r.nonce) 221 length := int(binary.BigEndian.Uint16(r.buffer[:PacketLengthBufferSize])) 222 end := length + Overhead 223 _, err = io.ReadFull(r.upstream, r.buffer[:end]) 224 if err != nil { 225 return err 226 } 227 _, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:end], nil) 228 if err != nil { 229 return err 230 } 231 increaseNonce(r.nonce) 232 r.cached = length 233 r.index = 0 234 return nil 235 } 236 237 func (r *Reader) ReadWithLength(length uint16) error { 238 end := int(length) + Overhead 239 _, err := io.ReadFull(r.upstream, r.buffer[:end]) 240 if err != nil { 241 return err 242 } 243 _, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:end], nil) 244 if err != nil { 245 return err 246 } 247 increaseNonce(r.nonce) 248 r.cached = int(length) 249 r.index = 0 250 return nil 251 } 252 253 func (r *Reader) ReadExternalChunk(chunk []byte) error { 254 bb, err := r.cipher.Open(r.buffer[:0], r.nonce, chunk, nil) 255 if err != nil { 256 return err 257 } 258 increaseNonce(r.nonce) 259 r.cached = len(bb) 260 r.index = 0 261 return nil 262 } 263 264 func (r *Reader) ReadChunk(buffer *buf.Buffer, chunk []byte) error { 265 bb, err := r.cipher.Open(buffer.Index(buffer.Len()), r.nonce, chunk, nil) 266 if err != nil { 267 return err 268 } 269 increaseNonce(r.nonce) 270 buffer.Extend(len(bb)) 271 return nil 272 } 273 274 type Writer struct { 275 upstream io.Writer 276 cipher cipher.AEAD 277 maxPacketSize int 278 buffer []byte 279 nonce []byte 280 access sync.Mutex 281 } 282 283 func NewWriter(upstream io.Writer, cipher cipher.AEAD, maxPacketSize int) *Writer { 284 return &Writer{ 285 upstream: upstream, 286 cipher: cipher, 287 buffer: make([]byte, maxPacketSize+PacketLengthBufferSize+Overhead*2), 288 nonce: make([]byte, cipher.NonceSize()), 289 maxPacketSize: maxPacketSize, 290 } 291 } 292 293 func NewRawWriter(upstream io.Writer, cipher cipher.AEAD, maxPacketSize int, buffer []byte, nonce []byte) *Writer { 294 return &Writer{ 295 upstream: upstream, 296 cipher: cipher, 297 maxPacketSize: maxPacketSize, 298 buffer: buffer, 299 nonce: nonce, 300 } 301 } 302 303 func (w *Writer) Upstream() any { 304 return w.upstream 305 } 306 307 func (w *Writer) ReadFrom(r io.Reader) (n int64, err error) { 308 for { 309 offset := Overhead + PacketLengthBufferSize 310 readN, readErr := r.Read(w.buffer[offset : offset+w.maxPacketSize]) 311 if readErr != nil { 312 return 0, readErr 313 } 314 binary.BigEndian.PutUint16(w.buffer[:PacketLengthBufferSize], uint16(readN)) 315 w.cipher.Seal(w.buffer[:0], w.nonce, w.buffer[:PacketLengthBufferSize], nil) 316 increaseNonce(w.nonce) 317 packet := w.cipher.Seal(w.buffer[offset:offset], w.nonce, w.buffer[offset:offset+readN], nil) 318 increaseNonce(w.nonce) 319 _, err = w.upstream.Write(w.buffer[:offset+len(packet)]) 320 if err != nil { 321 return 322 } 323 n += int64(readN) 324 } 325 } 326 327 func (w *Writer) Write(p []byte) (n int, err error) { 328 if len(p) == 0 { 329 return 330 } 331 332 for pLen := len(p); pLen > 0; { 333 var data []byte 334 if pLen > w.maxPacketSize { 335 data = p[:w.maxPacketSize] 336 p = p[w.maxPacketSize:] 337 pLen -= w.maxPacketSize 338 } else { 339 data = p 340 pLen = 0 341 } 342 w.access.Lock() 343 binary.BigEndian.PutUint16(w.buffer[:PacketLengthBufferSize], uint16(len(data))) 344 w.cipher.Seal(w.buffer[:0], w.nonce, w.buffer[:PacketLengthBufferSize], nil) 345 increaseNonce(w.nonce) 346 offset := Overhead + PacketLengthBufferSize 347 packet := w.cipher.Seal(w.buffer[offset:offset], w.nonce, data, nil) 348 increaseNonce(w.nonce) 349 w.access.Unlock() 350 _, err = w.upstream.Write(w.buffer[:offset+len(packet)]) 351 if err != nil { 352 return 353 } 354 n += len(data) 355 } 356 357 return 358 } 359 360 func (w *Writer) WriteVectorised(buffers []*buf.Buffer) error { 361 defer buf.ReleaseMulti(buffers) 362 var index int 363 var err error 364 for _, buffer := range buffers { 365 pLen := buffer.Len() 366 if pLen > w.maxPacketSize { 367 _, err = w.Write(buffer.Bytes()) 368 if err != nil { 369 return err 370 } 371 } else { 372 if cap(w.buffer) < index+PacketLengthBufferSize+pLen+2*Overhead { 373 _, err = w.upstream.Write(w.buffer[:index]) 374 index = 0 375 if err != nil { 376 return err 377 } 378 } 379 w.access.Lock() 380 binary.BigEndian.PutUint16(w.buffer[index:index+PacketLengthBufferSize], uint16(pLen)) 381 w.cipher.Seal(w.buffer[index:index], w.nonce, w.buffer[index:index+PacketLengthBufferSize], nil) 382 increaseNonce(w.nonce) 383 offset := index + Overhead + PacketLengthBufferSize 384 w.cipher.Seal(w.buffer[offset:offset], w.nonce, buffer.Bytes(), nil) 385 increaseNonce(w.nonce) 386 w.access.Unlock() 387 index = offset + pLen + Overhead 388 } 389 } 390 if index > 0 { 391 _, err = w.upstream.Write(w.buffer[:index]) 392 } 393 return err 394 } 395 396 func (w *Writer) Buffer() *buf.Buffer { 397 return buf.With(w.buffer) 398 } 399 400 func (w *Writer) WriteChunk(buffer *buf.Buffer, chunk []byte) { 401 bb := w.cipher.Seal(buffer.Index(buffer.Len()), w.nonce, chunk, nil) 402 buffer.Extend(len(bb)) 403 increaseNonce(w.nonce) 404 } 405 406 func (w *Writer) BufferedWriter(reversed int) *BufferedWriter { 407 return &BufferedWriter{ 408 upstream: w, 409 reversed: reversed, 410 data: w.buffer[PacketLengthBufferSize+Overhead : len(w.buffer)-Overhead], 411 } 412 } 413 414 type BufferedWriter struct { 415 upstream *Writer 416 data []byte 417 reversed int 418 index int 419 } 420 421 func (w *BufferedWriter) Write(p []byte) (n int, err error) { 422 for { 423 cachedN := copy(w.data[w.reversed+w.index:], p[n:]) 424 w.index += cachedN 425 if cachedN == len(p[n:]) { 426 n += cachedN 427 return 428 } 429 err = w.Flush() 430 if err != nil { 431 return 432 } 433 n += cachedN 434 } 435 } 436 437 func (w *BufferedWriter) Flush() error { 438 if w.index == 0 { 439 if w.reversed > 0 { 440 _, err := w.upstream.upstream.Write(w.upstream.buffer[:w.reversed]) 441 w.reversed = 0 442 return err 443 } 444 return nil 445 } 446 buffer := w.upstream.buffer[w.reversed:] 447 binary.BigEndian.PutUint16(buffer[:PacketLengthBufferSize], uint16(w.index)) 448 w.upstream.cipher.Seal(buffer[:0], w.upstream.nonce, buffer[:PacketLengthBufferSize], nil) 449 increaseNonce(w.upstream.nonce) 450 offset := Overhead + PacketLengthBufferSize 451 packet := w.upstream.cipher.Seal(buffer[offset:offset], w.upstream.nonce, buffer[offset:offset+w.index], nil) 452 increaseNonce(w.upstream.nonce) 453 _, err := w.upstream.upstream.Write(w.upstream.buffer[:w.reversed+offset+len(packet)]) 454 w.reversed = 0 455 w.index = 0 456 return err 457 } 458 459 func increaseNonce(nonce []byte) { 460 for i := range nonce { 461 nonce[i]++ 462 if nonce[i] != 0 { 463 return 464 } 465 } 466 }