github.com/keybase/client/go@v0.0.0-20240309051027-028f7c731f8b/chat/signencrypt/seeker.go (about) 1 package signencrypt 2 3 import ( 4 "context" 5 "io" 6 7 "github.com/keybase/client/go/chat/globals" 8 "github.com/keybase/client/go/chat/utils" 9 "github.com/keybase/client/go/kbcrypto" 10 11 lru "github.com/hashicorp/golang-lru" 12 ) 13 14 // decodingReadSeeker provies an io.ReadSeeker interface to playing back a signencrypt'd payload 15 type decodingReadSeeker struct { 16 utils.DebugLabeler 17 18 ctx context.Context 19 source io.ReadSeeker 20 encKey SecretboxKey 21 verifyKey VerifyKey 22 sigPrefix kbcrypto.SignaturePrefix 23 nonce Nonce 24 size, offset int64 25 chunks *lru.Cache 26 } 27 28 var _ io.ReadSeeker = (*decodingReadSeeker)(nil) 29 30 func NewDecodingReadSeeker(ctx context.Context, g *globals.Context, source io.ReadSeeker, size int64, 31 encKey SecretboxKey, verifyKey VerifyKey, signaturePrefix kbcrypto.SignaturePrefix, nonce Nonce, 32 c *lru.Cache) io.ReadSeeker { 33 if c == nil { 34 // If the caller didn't give us a cache, then let's just make one 35 c, _ = lru.New(20) 36 } 37 return &decodingReadSeeker{ 38 DebugLabeler: utils.NewDebugLabeler(g.ExternalG(), "DecodingReadSeeker", true), 39 source: source, 40 size: size, 41 chunks: c, 42 encKey: encKey, 43 verifyKey: verifyKey, 44 sigPrefix: signaturePrefix, 45 nonce: nonce, 46 } 47 } 48 49 // getChunksFromCache returns the plaintext bytes for a set of chunks iff we have each chunk 50 // in our cache 51 func (r *decodingReadSeeker) getChunksFromCache(chunks []chunkSpec) (res []byte, ok bool) { 52 for _, c := range chunks { 53 if pt, ok := r.chunks.Get(c.index); ok { 54 res = append(res, pt.([]byte)...) 55 r.Debug(r.ctx, "getChunksFromCache: added: index: %d len: %v", c.index, len(pt.([]byte))) 56 } else { 57 r.Debug(r.ctx, "getChunksFromCache: missed: %v", c.index) 58 return res, false 59 } 60 } 61 return res, true 62 } 63 64 func (r *decodingReadSeeker) writeChunksToCache(pt []byte, chunks []chunkSpec) { 65 start := chunks[0].ptStart 66 for _, c := range chunks { 67 stored := make([]byte, len(pt[c.ptStart-start:c.ptEnd-start])) 68 // need to pull the specific chunk out of the plaintext bytes 69 copy(stored, pt[c.ptStart-start:c.ptEnd-start]) 70 r.Debug(r.ctx, "writeChunksToCache: adding index: %d len: %d", c.index, len(stored)) 71 r.chunks.Add(c.index, stored) 72 } 73 } 74 75 func (r *decodingReadSeeker) fetchChunks(chunks []chunkSpec) (res []byte, err error) { 76 // we want to fetch enough data for all the chunks in one hit on the source ReadSeeker 77 begin := chunks[0].cipherStart 78 end := chunks[len(chunks)-1].cipherEnd 79 num := end - begin 80 81 if _, err := r.source.Seek(begin, io.SeekStart); err != nil { 82 return res, err 83 } 84 var bufOffset int64 85 res = make([]byte, num) 86 for { 87 n, err := r.source.Read(res[bufOffset:]) 88 if err != nil { 89 return res, err 90 } 91 bufOffset += int64(n) 92 if bufOffset >= num { 93 break 94 } 95 } 96 return res, nil 97 } 98 99 func (r *decodingReadSeeker) clamp(offset int64) int64 { 100 if offset >= r.size { 101 offset = r.size 102 } 103 return offset 104 } 105 106 func (r *decodingReadSeeker) extractPlaintext(plainText []byte, num int64, chunks []chunkSpec) []byte { 107 datBegin := chunks[0].ptStart 108 ptBegin := r.offset 109 ptEnd := r.clamp(r.offset + num) 110 r.Debug(r.ctx, "extractPlaintext: datBegin: %v ptBegin: %v ptEnd: %v", datBegin, ptBegin, ptEnd) 111 return plainText[ptBegin-datBegin : ptEnd-datBegin] 112 } 113 114 // getReadaheadFactor gives the number of chunks we should read at minimum from the source. For larger 115 // files we try to read more so we don't make too many underlying requests. 116 func (r *decodingReadSeeker) getReadaheadFactor() int64 { 117 mb := int64(1 << 20) 118 switch { 119 case r.size >= 1000*mb: 120 return 16 121 case r.size >= 500*mb: 122 return 8 123 default: 124 return 4 125 } 126 } 127 128 func (r *decodingReadSeeker) Read(res []byte) (n int, err error) { 129 defer r.Trace(r.ctx, &err, "Read(%v,%v)", r.offset, len(res))() 130 if r.offset >= r.size { 131 return 0, io.EOF 132 } 133 num := int64(len(res)) 134 chunkEnd := r.clamp(r.offset + num) 135 r.Debug(r.ctx, "Read: chunkEnd: %v", chunkEnd) 136 chunks := getChunksInRange(r.offset, chunkEnd, r.size) 137 var chunkPlaintext []byte 138 139 // Check for a full hit on all the chunks first 140 var ok bool 141 if chunkPlaintext, ok = r.getChunksFromCache(chunks); !ok { 142 // if we miss, then we need to fetch the data from our underlying source. Given that this 143 // source is usually on the network, then fetch at least K chunks so we aren't making 144 // too many requests. 145 minChunkEnd := r.clamp(r.offset + r.getReadaheadFactor()*DefaultPlaintextChunkLength) 146 if minChunkEnd > chunkEnd { 147 chunkEnd = minChunkEnd 148 } 149 prefetchChunks := getChunksInRange(r.offset, chunkEnd, r.size) 150 cipherText, err := r.fetchChunks(prefetchChunks) 151 if err != nil { 152 return n, err 153 } 154 for _, c := range prefetchChunks { 155 r.Debug(r.ctx, "Read: chunk: index: %v ptstart: %v ptend: %v cstart: %v cend: %v", c.index, 156 c.ptStart, c.ptEnd, c.cipherStart, c.cipherEnd) 157 } 158 // Decrypt all the chunks and write out to the cache 159 decoder := NewDecoder(r.encKey, r.verifyKey, r.sigPrefix, r.nonce) 160 decoder.setChunkNum(uint64(prefetchChunks[0].index)) 161 if chunkPlaintext, err = decoder.Write(cipherText); err != nil { 162 return n, err 163 } 164 // We might have some straggling data, so just hit Finish here to potentially pick it up. If it 165 // returns an error, then we just ignore it. 166 if finishPlaintext, err := decoder.Finish(); err == nil { 167 chunkPlaintext = append(chunkPlaintext, finishPlaintext...) 168 } 169 r.writeChunksToCache(chunkPlaintext, prefetchChunks) 170 } 171 172 r.Debug(r.ctx, "Read: len(chunkPlainText): %v", len(chunkPlaintext)) 173 plainText := r.extractPlaintext(chunkPlaintext, num, chunks) 174 copy(res, plainText) 175 numRead := int64(len(plainText)) 176 r.Debug(r.ctx, "Read: len(pt): %v", len(plainText)) 177 r.offset += numRead 178 return int(numRead), nil 179 } 180 181 func (r *decodingReadSeeker) Seek(offset int64, whence int) (res int64, err error) { 182 defer r.Trace(r.ctx, &err, "Seek(%v,%v)", offset, whence)() 183 switch whence { 184 case io.SeekStart: 185 r.offset = offset 186 case io.SeekCurrent: 187 r.offset += offset 188 case io.SeekEnd: 189 r.offset = r.size - offset 190 } 191 return r.offset, nil 192 }