github.com/MerlinKodo/sing-shadowsocks@v0.2.6/shadowaead/aead.go (about) 1 package shadowaead 2 3 import ( 4 "crypto/cipher" 5 "encoding/binary" 6 "io" 7 "sync" 8 9 "github.com/sagernet/sing/common/buf" 10 ) 11 12 // https://shadowsocks.org/en/wiki/AEAD-Ciphers.html 13 const ( 14 MaxPacketSize = 16*1024 - 1 15 PacketLengthBufferSize = 2 16 ) 17 18 const ( 19 // Overhead 20 // crypto/cipher.gcmTagSize 21 // golang.org/x/crypto/chacha20poly1305.Overhead 22 // github.com/sina-ghaderi/poly1305.TagSize 23 // github.com/ericlagergren/siv.TagSize 24 // github.com/ericlagergren/aegis.TagSize128L 25 // github.com/ericlagergren/aegis.TagSize256 26 // github.com/Yawning/aez.aeadOverhead 27 // github.com/oasisprotocol/deoxysii.TagSize 28 Overhead = 16 29 ) 30 31 type Reader struct { 32 upstream io.Reader 33 cipher cipher.AEAD 34 buffer []byte 35 nonce []byte 36 index int 37 cached int 38 } 39 40 func NewReader(upstream io.Reader, cipher cipher.AEAD, maxPacketSize int) *Reader { 41 return &Reader{ 42 upstream: upstream, 43 cipher: cipher, 44 buffer: make([]byte, maxPacketSize+Overhead), 45 nonce: make([]byte, cipher.NonceSize()), 46 } 47 } 48 49 func NewRawReader(upstream io.Reader, cipher cipher.AEAD, buffer []byte, nonce []byte) *Reader { 50 return &Reader{ 51 upstream: upstream, 52 cipher: cipher, 53 buffer: buffer, 54 nonce: nonce, 55 } 56 } 57 58 func (r *Reader) Upstream() any { 59 return r.upstream 60 } 61 62 func (r *Reader) WriteTo(writer io.Writer) (n int64, err error) { 63 if r.cached > 0 { 64 writeN, writeErr := writer.Write(r.buffer[r.index : r.index+r.cached]) 65 if writeErr != nil { 66 return int64(writeN), writeErr 67 } 68 n += int64(writeN) 69 } 70 for { 71 start := PacketLengthBufferSize + Overhead 72 _, err = io.ReadFull(r.upstream, r.buffer[:start]) 73 if err != nil { 74 return 75 } 76 _, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:start], nil) 77 if err != nil { 78 return 79 } 80 increaseNonce(r.nonce) 81 length := int(binary.BigEndian.Uint16(r.buffer[:PacketLengthBufferSize])) 82 end := length + Overhead 83 _, err = io.ReadFull(r.upstream, r.buffer[:end]) 84 if err != nil { 85 return 86 } 87 _, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:end], nil) 88 if err != nil { 89 return 90 } 91 increaseNonce(r.nonce) 92 writeN, writeErr := writer.Write(r.buffer[:length]) 93 if writeErr != nil { 94 return int64(writeN), writeErr 95 } 96 n += int64(writeN) 97 } 98 } 99 100 func (r *Reader) readInternal() (err error) { 101 start := PacketLengthBufferSize + Overhead 102 _, err = io.ReadFull(r.upstream, r.buffer[:start]) 103 if err != nil { 104 return err 105 } 106 _, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:start], nil) 107 if err != nil { 108 return err 109 } 110 increaseNonce(r.nonce) 111 length := int(binary.BigEndian.Uint16(r.buffer[:PacketLengthBufferSize])) 112 end := length + Overhead 113 _, err = io.ReadFull(r.upstream, r.buffer[:end]) 114 if err != nil { 115 return err 116 } 117 _, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:end], nil) 118 if err != nil { 119 return err 120 } 121 increaseNonce(r.nonce) 122 r.cached = length 123 r.index = 0 124 return nil 125 } 126 127 func (r *Reader) ReadByte() (byte, error) { 128 if r.cached == 0 { 129 err := r.readInternal() 130 if err != nil { 131 return 0, err 132 } 133 } 134 index := r.index 135 r.index++ 136 r.cached-- 137 return r.buffer[index], nil 138 } 139 140 func (r *Reader) Read(b []byte) (n int, err error) { 141 if r.cached > 0 { 142 n = copy(b, r.buffer[r.index:r.index+r.cached]) 143 r.cached -= n 144 r.index += n 145 return 146 } 147 start := PacketLengthBufferSize + Overhead 148 _, err = io.ReadFull(r.upstream, r.buffer[:start]) 149 if err != nil { 150 return 0, err 151 } 152 _, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:start], nil) 153 if err != nil { 154 return 0, err 155 } 156 increaseNonce(r.nonce) 157 length := int(binary.BigEndian.Uint16(r.buffer[:PacketLengthBufferSize])) 158 end := length + Overhead 159 160 if len(b) >= end { 161 data := b[:end] 162 _, err = io.ReadFull(r.upstream, data) 163 if err != nil { 164 return 0, err 165 } 166 _, err = r.cipher.Open(b[:0], r.nonce, data, nil) 167 if err != nil { 168 return 0, err 169 } 170 increaseNonce(r.nonce) 171 return length, nil 172 } else { 173 _, err = io.ReadFull(r.upstream, r.buffer[:end]) 174 if err != nil { 175 return 0, err 176 } 177 _, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:end], nil) 178 if err != nil { 179 return 0, err 180 } 181 increaseNonce(r.nonce) 182 n = copy(b, r.buffer[:length]) 183 r.cached = length - n 184 r.index = n 185 return 186 } 187 } 188 189 func (r *Reader) Discard(n int) error { 190 for { 191 if r.cached >= n { 192 r.cached -= n 193 r.index += n 194 return nil 195 } else if r.cached > 0 { 196 n -= r.cached 197 r.cached = 0 198 r.index = 0 199 } 200 err := r.readInternal() 201 if err != nil { 202 return err 203 } 204 } 205 } 206 207 func (r *Reader) Buffer() *buf.Buffer { 208 buffer := buf.With(r.buffer) 209 buffer.Resize(r.index, r.cached) 210 return buffer 211 } 212 213 func (r *Reader) Cached() int { 214 return r.cached 215 } 216 217 func (r *Reader) CachedSlice() []byte { 218 return r.buffer[r.index : r.index+r.cached] 219 } 220 221 func (r *Reader) ReadWithLengthChunk(lengthChunk []byte) error { 222 _, err := r.cipher.Open(r.buffer[:0], r.nonce, lengthChunk, nil) 223 if err != nil { 224 return err 225 } 226 increaseNonce(r.nonce) 227 length := int(binary.BigEndian.Uint16(r.buffer[:PacketLengthBufferSize])) 228 end := length + Overhead 229 _, err = io.ReadFull(r.upstream, r.buffer[:end]) 230 if err != nil { 231 return err 232 } 233 _, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:end], nil) 234 if err != nil { 235 return err 236 } 237 increaseNonce(r.nonce) 238 r.cached = length 239 r.index = 0 240 return nil 241 } 242 243 func (r *Reader) ReadWithLength(length uint16) error { 244 end := int(length) + Overhead 245 _, err := io.ReadFull(r.upstream, r.buffer[:end]) 246 if err != nil { 247 return err 248 } 249 _, err = r.cipher.Open(r.buffer[:0], r.nonce, r.buffer[:end], nil) 250 if err != nil { 251 return err 252 } 253 increaseNonce(r.nonce) 254 r.cached = int(length) 255 r.index = 0 256 return nil 257 } 258 259 func (r *Reader) ReadExternalChunk(chunk []byte) error { 260 bb, err := r.cipher.Open(r.buffer[:0], r.nonce, chunk, nil) 261 if err != nil { 262 return err 263 } 264 increaseNonce(r.nonce) 265 r.cached = len(bb) 266 r.index = 0 267 return nil 268 } 269 270 func (r *Reader) ReadChunk(buffer *buf.Buffer, chunk []byte) error { 271 bb, err := r.cipher.Open(buffer.Index(buffer.Len()), r.nonce, chunk, nil) 272 if err != nil { 273 return err 274 } 275 increaseNonce(r.nonce) 276 buffer.Extend(len(bb)) 277 return nil 278 } 279 280 type Writer struct { 281 upstream io.Writer 282 cipher cipher.AEAD 283 maxPacketSize int 284 buffer []byte 285 nonce []byte 286 access sync.Mutex 287 } 288 289 func NewWriter(upstream io.Writer, cipher cipher.AEAD, maxPacketSize int) *Writer { 290 return &Writer{ 291 upstream: upstream, 292 cipher: cipher, 293 buffer: make([]byte, maxPacketSize+PacketLengthBufferSize+Overhead*2), 294 nonce: make([]byte, cipher.NonceSize()), 295 maxPacketSize: maxPacketSize, 296 } 297 } 298 299 func NewRawWriter(upstream io.Writer, cipher cipher.AEAD, maxPacketSize int, buffer []byte, nonce []byte) *Writer { 300 return &Writer{ 301 upstream: upstream, 302 cipher: cipher, 303 maxPacketSize: maxPacketSize, 304 buffer: buffer, 305 nonce: nonce, 306 } 307 } 308 309 func (w *Writer) Upstream() any { 310 return w.upstream 311 } 312 313 func (w *Writer) ReadFrom(r io.Reader) (n int64, err error) { 314 for { 315 offset := Overhead + PacketLengthBufferSize 316 readN, readErr := r.Read(w.buffer[offset : offset+w.maxPacketSize]) 317 if readErr != nil { 318 return 0, readErr 319 } 320 binary.BigEndian.PutUint16(w.buffer[:PacketLengthBufferSize], uint16(readN)) 321 w.cipher.Seal(w.buffer[:0], w.nonce, w.buffer[:PacketLengthBufferSize], nil) 322 increaseNonce(w.nonce) 323 packet := w.cipher.Seal(w.buffer[offset:offset], w.nonce, w.buffer[offset:offset+readN], nil) 324 increaseNonce(w.nonce) 325 _, err = w.upstream.Write(w.buffer[:offset+len(packet)]) 326 if err != nil { 327 return 328 } 329 n += int64(readN) 330 } 331 } 332 333 func (w *Writer) Write(p []byte) (n int, err error) { 334 if len(p) == 0 { 335 return 336 } 337 338 for pLen := len(p); pLen > 0; { 339 var data []byte 340 if pLen > w.maxPacketSize { 341 data = p[:w.maxPacketSize] 342 p = p[w.maxPacketSize:] 343 pLen -= w.maxPacketSize 344 } else { 345 data = p 346 pLen = 0 347 } 348 w.access.Lock() 349 binary.BigEndian.PutUint16(w.buffer[:PacketLengthBufferSize], uint16(len(data))) 350 w.cipher.Seal(w.buffer[:0], w.nonce, w.buffer[:PacketLengthBufferSize], nil) 351 increaseNonce(w.nonce) 352 offset := Overhead + PacketLengthBufferSize 353 packet := w.cipher.Seal(w.buffer[offset:offset], w.nonce, data, nil) 354 increaseNonce(w.nonce) 355 w.access.Unlock() 356 _, err = w.upstream.Write(w.buffer[:offset+len(packet)]) 357 if err != nil { 358 return 359 } 360 n += len(data) 361 } 362 363 return 364 } 365 366 func (w *Writer) WriteVectorised(buffers []*buf.Buffer) error { 367 defer buf.ReleaseMulti(buffers) 368 var index int 369 var err error 370 for _, buffer := range buffers { 371 pLen := buffer.Len() 372 if pLen > w.maxPacketSize { 373 _, err = w.Write(buffer.Bytes()) 374 if err != nil { 375 return err 376 } 377 } else { 378 if cap(w.buffer) < index+PacketLengthBufferSize+pLen+2*Overhead { 379 _, err = w.upstream.Write(w.buffer[:index]) 380 index = 0 381 if err != nil { 382 return err 383 } 384 } 385 w.access.Lock() 386 binary.BigEndian.PutUint16(w.buffer[index:index+PacketLengthBufferSize], uint16(pLen)) 387 w.cipher.Seal(w.buffer[index:index], w.nonce, w.buffer[index:index+PacketLengthBufferSize], nil) 388 increaseNonce(w.nonce) 389 offset := index + Overhead + PacketLengthBufferSize 390 w.cipher.Seal(w.buffer[offset:offset], w.nonce, buffer.Bytes(), nil) 391 increaseNonce(w.nonce) 392 w.access.Unlock() 393 index = offset + pLen + Overhead 394 } 395 } 396 if index > 0 { 397 _, err = w.upstream.Write(w.buffer[:index]) 398 } 399 return err 400 } 401 402 func (w *Writer) Buffer() *buf.Buffer { 403 return buf.With(w.buffer) 404 } 405 406 func (w *Writer) WriteChunk(buffer *buf.Buffer, chunk []byte) { 407 bb := w.cipher.Seal(buffer.Index(buffer.Len()), w.nonce, chunk, nil) 408 buffer.Extend(len(bb)) 409 increaseNonce(w.nonce) 410 } 411 412 func (w *Writer) BufferedWriter(reversed int) *BufferedWriter { 413 return &BufferedWriter{ 414 upstream: w, 415 reversed: reversed, 416 data: w.buffer[PacketLengthBufferSize+Overhead : len(w.buffer)-Overhead], 417 } 418 } 419 420 type BufferedWriter struct { 421 upstream *Writer 422 data []byte 423 reversed int 424 index int 425 } 426 427 func (w *BufferedWriter) Write(p []byte) (n int, err error) { 428 for { 429 cachedN := copy(w.data[w.reversed+w.index:], p[n:]) 430 w.index += cachedN 431 if cachedN == len(p[n:]) { 432 n += cachedN 433 return 434 } 435 err = w.Flush() 436 if err != nil { 437 return 438 } 439 n += cachedN 440 } 441 } 442 443 func (w *BufferedWriter) Flush() error { 444 if w.index == 0 { 445 if w.reversed > 0 { 446 _, err := w.upstream.upstream.Write(w.upstream.buffer[:w.reversed]) 447 w.reversed = 0 448 return err 449 } 450 return nil 451 } 452 buffer := w.upstream.buffer[w.reversed:] 453 binary.BigEndian.PutUint16(buffer[:PacketLengthBufferSize], uint16(w.index)) 454 w.upstream.cipher.Seal(buffer[:0], w.upstream.nonce, buffer[:PacketLengthBufferSize], nil) 455 increaseNonce(w.upstream.nonce) 456 offset := Overhead + PacketLengthBufferSize 457 packet := w.upstream.cipher.Seal(buffer[offset:offset], w.upstream.nonce, buffer[offset:offset+w.index], nil) 458 increaseNonce(w.upstream.nonce) 459 _, err := w.upstream.upstream.Write(w.upstream.buffer[:w.reversed+offset+len(packet)]) 460 w.reversed = 0 461 w.index = 0 462 return err 463 } 464 465 func increaseNonce(nonce []byte) { 466 for i := range nonce { 467 nonce[i]++ 468 if nonce[i] != 0 { 469 return 470 } 471 } 472 }