github.com/sagernet/sing-shadowsocks2@v0.2.0/internal/shadowio/reader.go (about) 1 package shadowio 2 3 import ( 4 "crypto/cipher" 5 "encoding/binary" 6 "io" 7 8 "github.com/sagernet/sing/common/buf" 9 N "github.com/sagernet/sing/common/network" 10 ) 11 12 const PacketLengthBufferSize = 2 13 14 const ( 15 // Overhead 16 // crypto/cipher.gcmTagSize 17 // golang.org/x/crypto/chacha20poly1305.Overhead 18 Overhead = 16 19 ) 20 21 var ( 22 _ N.ExtendedReader = (*Reader)(nil) 23 _ N.ReadWaiter = (*Reader)(nil) 24 ) 25 26 type Reader struct { 27 reader io.Reader 28 cipher cipher.AEAD 29 nonce []byte 30 cache *buf.Buffer 31 readWaitOptions N.ReadWaitOptions 32 } 33 34 func NewReader(upstream io.Reader, cipher cipher.AEAD) *Reader { 35 return &Reader{ 36 reader: upstream, 37 cipher: cipher, 38 nonce: make([]byte, cipher.NonceSize()), 39 } 40 } 41 42 func (r *Reader) ReadFixedBuffer(pLen int) (*buf.Buffer, error) { 43 buffer := buf.NewSize(pLen + Overhead) 44 _, err := buffer.ReadFullFrom(r.reader, buffer.FreeLen()) 45 if err != nil { 46 buffer.Release() 47 return nil, err 48 } 49 err = r.Decrypt(buffer.Index(0), buffer.Bytes()) 50 if err != nil { 51 buffer.Release() 52 return nil, err 53 } 54 buffer.Truncate(pLen) 55 r.cache = buffer 56 return buffer, nil 57 } 58 59 func (r *Reader) Decrypt(destination []byte, source []byte) error { 60 _, err := r.cipher.Open(destination[:0], r.nonce, source, nil) 61 if err != nil { 62 return err 63 } 64 increaseNonce(r.nonce) 65 return nil 66 } 67 68 func (r *Reader) Read(p []byte) (n int, err error) { 69 for { 70 if r.cache != nil { 71 if r.cache.IsEmpty() { 72 r.cache.Release() 73 r.cache = nil 74 } else { 75 n = copy(p, r.cache.Bytes()) 76 if n > 0 { 77 r.cache.Advance(n) 78 return 79 } 80 } 81 } 82 r.cache, err = r.readBuffer() 83 if err != nil { 84 return 85 } 86 } 87 } 88 89 func (r *Reader) ReadBuffer(buffer *buf.Buffer) error { 90 var err error 91 for { 92 if r.cache != nil { 93 if r.cache.IsEmpty() { 94 r.cache.Release() 95 r.cache = nil 96 } else { 97 n := copy(buffer.FreeBytes(), r.cache.Bytes()) 98 if n > 0 { 99 buffer.Truncate(n) 100 r.cache.Advance(n) 101 return nil 102 } 103 } 104 } 105 r.cache, err = r.readBuffer() 106 if err != nil { 107 return err 108 } 109 } 110 } 111 112 func (r *Reader) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { 113 r.readWaitOptions = options 114 return options.NeedHeadroom() 115 } 116 117 func (r *Reader) WaitReadBuffer() (buffer *buf.Buffer, err error) { 118 if r.readWaitOptions.NeedHeadroom() { 119 for { 120 if r.cache != nil { 121 if r.cache.IsEmpty() { 122 r.cache.Release() 123 r.cache = nil 124 } else { 125 buffer = r.readWaitOptions.NewBuffer() 126 var n int 127 n, err = buffer.Write(r.cache.Bytes()) 128 if err != nil { 129 buffer.Release() 130 return 131 } 132 buffer.Truncate(n) 133 r.cache.Advance(n) 134 r.readWaitOptions.PostReturn(buffer) 135 return 136 } 137 } 138 r.cache, err = r.readBuffer() 139 if err != nil { 140 return 141 } 142 } 143 } else { 144 cache := r.cache 145 if cache != nil { 146 r.cache = nil 147 return cache, nil 148 } 149 return r.readBuffer() 150 } 151 } 152 153 func (r *Reader) readBuffer() (*buf.Buffer, error) { 154 buffer := buf.NewSize(PacketLengthBufferSize + Overhead) 155 _, err := buffer.ReadFullFrom(r.reader, buffer.FreeLen()) 156 if err != nil { 157 buffer.Release() 158 return nil, err 159 } 160 _, err = r.cipher.Open(buffer.Index(0), r.nonce, buffer.Bytes(), nil) 161 if err != nil { 162 buffer.Release() 163 return nil, err 164 } 165 increaseNonce(r.nonce) 166 length := int(binary.BigEndian.Uint16(buffer.To(PacketLengthBufferSize))) 167 buffer.Release() 168 buffer = buf.NewSize(length + Overhead) 169 _, err = buffer.ReadFullFrom(r.reader, buffer.FreeLen()) 170 if err != nil { 171 buffer.Release() 172 return nil, err 173 } 174 _, err = r.cipher.Open(buffer.Index(0), r.nonce, buffer.Bytes(), nil) 175 if err != nil { 176 buffer.Release() 177 return nil, err 178 } 179 increaseNonce(r.nonce) 180 buffer.Truncate(length) 181 return buffer, nil 182 }