github.com/keybase/client/go@v0.0.0-20240309051027-028f7c731f8b/chat/signencrypt/seeker.go (about)

     1  package signencrypt
     2  
     3  import (
     4  	"context"
     5  	"io"
     6  
     7  	"github.com/keybase/client/go/chat/globals"
     8  	"github.com/keybase/client/go/chat/utils"
     9  	"github.com/keybase/client/go/kbcrypto"
    10  
    11  	lru "github.com/hashicorp/golang-lru"
    12  )
    13  
// decodingReadSeeker provides an io.ReadSeeker interface to playing back a signencrypt'd payload
type decodingReadSeeker struct {
	utils.DebugLabeler

	ctx          context.Context // context used for Debug/Trace logging
	source       io.ReadSeeker   // underlying ciphertext source (usually network-backed, per Read)
	encKey       SecretboxKey    // secretbox key used to decrypt chunks
	verifyKey    VerifyKey       // key used to verify chunk signatures
	sigPrefix    kbcrypto.SignaturePrefix
	nonce        Nonce
	size, offset int64      // size: total plaintext size; offset: current plaintext read position
	chunks       *lru.Cache // cache of decrypted plaintext chunks, keyed by chunk index
}

// compile-time check that decodingReadSeeker satisfies io.ReadSeeker
var _ io.ReadSeeker = (*decodingReadSeeker)(nil)
    29  
    30  func NewDecodingReadSeeker(ctx context.Context, g *globals.Context, source io.ReadSeeker, size int64,
    31  	encKey SecretboxKey, verifyKey VerifyKey, signaturePrefix kbcrypto.SignaturePrefix, nonce Nonce,
    32  	c *lru.Cache) io.ReadSeeker {
    33  	if c == nil {
    34  		// If the caller didn't give us a cache, then let's just make one
    35  		c, _ = lru.New(20)
    36  	}
    37  	return &decodingReadSeeker{
    38  		DebugLabeler: utils.NewDebugLabeler(g.ExternalG(), "DecodingReadSeeker", true),
    39  		source:       source,
    40  		size:         size,
    41  		chunks:       c,
    42  		encKey:       encKey,
    43  		verifyKey:    verifyKey,
    44  		sigPrefix:    signaturePrefix,
    45  		nonce:        nonce,
    46  	}
    47  }
    48  
    49  // getChunksFromCache returns the plaintext bytes for a set of chunks iff we have each chunk
    50  // in our cache
    51  func (r *decodingReadSeeker) getChunksFromCache(chunks []chunkSpec) (res []byte, ok bool) {
    52  	for _, c := range chunks {
    53  		if pt, ok := r.chunks.Get(c.index); ok {
    54  			res = append(res, pt.([]byte)...)
    55  			r.Debug(r.ctx, "getChunksFromCache: added: index: %d len: %v", c.index, len(pt.([]byte)))
    56  		} else {
    57  			r.Debug(r.ctx, "getChunksFromCache: missed: %v", c.index)
    58  			return res, false
    59  		}
    60  	}
    61  	return res, true
    62  }
    63  
    64  func (r *decodingReadSeeker) writeChunksToCache(pt []byte, chunks []chunkSpec) {
    65  	start := chunks[0].ptStart
    66  	for _, c := range chunks {
    67  		stored := make([]byte, len(pt[c.ptStart-start:c.ptEnd-start]))
    68  		// need to pull the specific chunk out of the plaintext bytes
    69  		copy(stored, pt[c.ptStart-start:c.ptEnd-start])
    70  		r.Debug(r.ctx, "writeChunksToCache: adding index: %d len: %d", c.index, len(stored))
    71  		r.chunks.Add(c.index, stored)
    72  	}
    73  }
    74  
    75  func (r *decodingReadSeeker) fetchChunks(chunks []chunkSpec) (res []byte, err error) {
    76  	// we want to fetch enough data for all the chunks in one hit on the source ReadSeeker
    77  	begin := chunks[0].cipherStart
    78  	end := chunks[len(chunks)-1].cipherEnd
    79  	num := end - begin
    80  
    81  	if _, err := r.source.Seek(begin, io.SeekStart); err != nil {
    82  		return res, err
    83  	}
    84  	var bufOffset int64
    85  	res = make([]byte, num)
    86  	for {
    87  		n, err := r.source.Read(res[bufOffset:])
    88  		if err != nil {
    89  			return res, err
    90  		}
    91  		bufOffset += int64(n)
    92  		if bufOffset >= num {
    93  			break
    94  		}
    95  	}
    96  	return res, nil
    97  }
    98  
    99  func (r *decodingReadSeeker) clamp(offset int64) int64 {
   100  	if offset >= r.size {
   101  		offset = r.size
   102  	}
   103  	return offset
   104  }
   105  
   106  func (r *decodingReadSeeker) extractPlaintext(plainText []byte, num int64, chunks []chunkSpec) []byte {
   107  	datBegin := chunks[0].ptStart
   108  	ptBegin := r.offset
   109  	ptEnd := r.clamp(r.offset + num)
   110  	r.Debug(r.ctx, "extractPlaintext: datBegin: %v ptBegin: %v ptEnd: %v", datBegin, ptBegin, ptEnd)
   111  	return plainText[ptBegin-datBegin : ptEnd-datBegin]
   112  }
   113  
   114  // getReadaheadFactor gives the number of chunks we should read at minimum from the source. For larger
   115  // files we try to read more so we don't make too many underlying requests.
   116  func (r *decodingReadSeeker) getReadaheadFactor() int64 {
   117  	mb := int64(1 << 20)
   118  	switch {
   119  	case r.size >= 1000*mb:
   120  		return 16
   121  	case r.size >= 500*mb:
   122  		return 8
   123  	default:
   124  		return 4
   125  	}
   126  }
   127  
// Read implements io.Reader: it fills res with decrypted plaintext starting
// at the current offset. On a cache miss it prefetches several chunks of
// ciphertext from the source, decrypts and verifies them with a fresh
// decoder, and caches the resulting plaintext per chunk index.
func (r *decodingReadSeeker) Read(res []byte) (n int, err error) {
	defer r.Trace(r.ctx, &err, "Read(%v,%v)", r.offset, len(res))()
	if r.offset >= r.size {
		return 0, io.EOF
	}
	num := int64(len(res))
	// End of the plaintext range the caller asked for, clamped to payload size.
	chunkEnd := r.clamp(r.offset + num)
	r.Debug(r.ctx, "Read: chunkEnd: %v", chunkEnd)
	chunks := getChunksInRange(r.offset, chunkEnd, r.size)
	var chunkPlaintext []byte

	// Check for a full hit on all the chunks first
	var ok bool
	if chunkPlaintext, ok = r.getChunksFromCache(chunks); !ok {
		// if we miss, then we need to fetch the data from our underlying source. Given that this
		// source is usually on the network, then fetch at least K chunks so we aren't making
		// too many requests.
		minChunkEnd := r.clamp(r.offset + r.getReadaheadFactor()*DefaultPlaintextChunkLength)
		if minChunkEnd > chunkEnd {
			chunkEnd = minChunkEnd
		}
		prefetchChunks := getChunksInRange(r.offset, chunkEnd, r.size)
		cipherText, err := r.fetchChunks(prefetchChunks)
		if err != nil {
			return n, err
		}
		for _, c := range prefetchChunks {
			r.Debug(r.ctx, "Read: chunk: index: %v ptstart: %v ptend: %v cstart: %v cend: %v", c.index,
				c.ptStart, c.ptEnd, c.cipherStart, c.cipherEnd)
		}
		// Decrypt all the chunks and write out to the cache
		decoder := NewDecoder(r.encKey, r.verifyKey, r.sigPrefix, r.nonce)
		// Start the decoder at the first prefetched chunk's index so its
		// per-chunk nonce sequence lines up with the chunk positions.
		decoder.setChunkNum(uint64(prefetchChunks[0].index))
		if chunkPlaintext, err = decoder.Write(cipherText); err != nil {
			return n, err
		}
		// We might have some straggling data, so just hit Finish here to potentially pick it up. If it
		// returns an error, then we just ignore it.
		if finishPlaintext, err := decoder.Finish(); err == nil {
			chunkPlaintext = append(chunkPlaintext, finishPlaintext...)
		}
		r.writeChunksToCache(chunkPlaintext, prefetchChunks)
	}

	r.Debug(r.ctx, "Read: len(chunkPlainText): %v", len(chunkPlaintext))
	// chunkPlaintext covers whole chunks; carve out just the bytes the caller
	// requested, relative to the first chunk's plaintext start.
	plainText := r.extractPlaintext(chunkPlaintext, num, chunks)
	copy(res, plainText)
	numRead := int64(len(plainText))
	r.Debug(r.ctx, "Read: len(pt): %v", len(plainText))
	r.offset += numRead
	return int(numRead), nil
}
   180  
   181  func (r *decodingReadSeeker) Seek(offset int64, whence int) (res int64, err error) {
   182  	defer r.Trace(r.ctx, &err, "Seek(%v,%v)", offset, whence)()
   183  	switch whence {
   184  	case io.SeekStart:
   185  		r.offset = offset
   186  	case io.SeekCurrent:
   187  		r.offset += offset
   188  	case io.SeekEnd:
   189  		r.offset = r.size - offset
   190  	}
   191  	return r.offset, nil
   192  }