github.com/keybase/client/go@v0.0.0-20241007131713-f10651d043c8/chat/attachments/store.go

package attachments

import (
	"bytes"
	"crypto/hmac"
	"crypto/sha256"
	"errors"
	"fmt"
	"hash"
	"io"
	"os"
	"path/filepath"
	"sync"

	"github.com/keybase/client/go/chat/attachments/progress"
	"github.com/keybase/client/go/chat/globals"
	"github.com/keybase/client/go/chat/types"
	"github.com/keybase/client/go/kbcrypto"
	"github.com/keybase/client/go/libkb"
	"github.com/keybase/go-crypto/ed25519"
	"github.com/keybase/go-framed-msgpack-rpc/rpc"

	lru "github.com/hashicorp/golang-lru"
	"github.com/keybase/client/go/chat/s3"
	"github.com/keybase/client/go/chat/signencrypt"
	"github.com/keybase/client/go/chat/utils"
	"github.com/keybase/client/go/protocol/chat1"
	"github.com/keybase/client/go/protocol/gregor1"
	"golang.org/x/net/context"
)

type ReadResetter interface {
	io.Reader
	Reset() error
}

type UploadTask struct {
	S3Params       chat1.S3Params
	Filename       string
	FileSize       int64
	Plaintext      ReadResetter
	taskHash       []byte
	S3Signer       s3.Signer
	ConversationID chat1.ConversationID
	UserID         gregor1.UID
	OutboxID       chat1.OutboxID
	Preview        bool
	Progress       types.ProgressReporter
}

func (u *UploadTask) computeHash() {
	hasher := sha256.New()
	seed := fmt.Sprintf("%s:%v", u.OutboxID, u.Preview)
	_, _ = hasher.Write([]byte(seed)) // hash.Hash writes never return an error
	u.taskHash = hasher.Sum(nil)
}

func (u *UploadTask) hash() []byte {
	return u.taskHash
}

func (u *UploadTask) stashKey() StashKey {
	return NewStashKey(u.OutboxID, u.Preview)
}

func (u *UploadTask) Nonce() signencrypt.Nonce {
	var n [signencrypt.NonceSize]byte
	copy(n[:], u.taskHash)
	return &n
}
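
// Both the task hash and the nonce derive only from (OutboxID, Preview), not
// from the file contents: recomputing the hash for the same outbox record
// reproduces the same nonce, which is what lets a resumed upload keep using
// its original keys. An illustrative equivalence:
//
//	t1 := &UploadTask{OutboxID: obid, Preview: false}
//	t2 := &UploadTask{OutboxID: obid, Preview: false}
//	t1.computeHash()
//	t2.computeHash()
//	// bytes.Equal(t1.taskHash, t2.taskHash) is always true here,
//	// so t1.Nonce() and t2.Nonce() contain identical bytes.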

type Store interface {
	UploadAsset(ctx context.Context, task *UploadTask, encryptedOut io.Writer) (chat1.Asset, error)
	DownloadAsset(ctx context.Context, params chat1.S3Params, asset chat1.Asset, w io.Writer,
		signer s3.Signer, progress types.ProgressReporter) error
	GetAssetReader(ctx context.Context, params chat1.S3Params, asset chat1.Asset,
		signer s3.Signer) (io.ReadCloser, error)
	StreamAsset(ctx context.Context, params chat1.S3Params, asset chat1.Asset, signer s3.Signer) (io.ReadSeeker, error)
	DecryptAsset(ctx context.Context, w io.Writer, body io.Reader, asset chat1.Asset,
		progress types.ProgressReporter) error
	DeleteAsset(ctx context.Context, params chat1.S3Params, signer s3.Signer, asset chat1.Asset) error
	DeleteAssets(ctx context.Context, params chat1.S3Params, signer s3.Signer, assets []chat1.Asset) error
}
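
// exampleUpload is an illustrative sketch (not called anywhere in this
// package) of how a caller typically drives the Store interface: fill in an
// UploadTask and hand it to UploadAsset. The parameter values are assumed to
// come from the chat service; io.Discard stands in for a caller that does not
// want its own copy of the ciphertext.
func exampleUpload(ctx context.Context, store Store, params chat1.S3Params,
	signer s3.Signer, src ReadResetter, size int64, convID chat1.ConversationID,
	uid gregor1.UID, obid chat1.OutboxID) (chat1.Asset, error) {
	task := &UploadTask{
		S3Params:       params,
		Filename:       "photo.jpg", // display name recorded in the resulting asset
		FileSize:       size,
		Plaintext:      src,
		S3Signer:       signer,
		ConversationID: convID,
		UserID:         uid,
		OutboxID:       obid,
	}
	return store.UploadAsset(ctx, task, io.Discard)
}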

type streamCache struct {
	path  string
	cache *lru.Cache
}

type S3Store struct {
	globals.Contextified
	utils.DebugLabeler

	s3c   s3.Root
	stash AttachmentStash

	scMutex     sync.Mutex
	streamCache *streamCache

	// testing hooks
	testing    bool                        // true if we're in a test
	keyTester  func(encKey, sigKey []byte) // used for testing only to check key changes
	aborts     int                         // number of aborts
	blockLimit int                         // max number of blocks to upload
}

// NewS3Store creates a standard Store that uses a real
// S3 connection.
func NewS3Store(g *globals.Context, runtimeDir string) *S3Store {
	return &S3Store{
		Contextified: globals.NewContextified(g),
		DebugLabeler: utils.NewDebugLabeler(g.ExternalG(), "Attachments.Store", false),
		s3c:          &s3.AWS{},
		stash:        NewFileStash(runtimeDir),
	}
}

// NewStoreTesting creates a Store suitable for testing
// purposes. It is not exposed outside this package.
// It uses an in-memory s3 interface, reports enc/sig keys, and allows limiting
// the number of blocks uploaded.
func NewStoreTesting(g *globals.Context, kt func(enc, sig []byte)) *S3Store {
	return &S3Store{
		Contextified: globals.NewContextified(g),
		DebugLabeler: utils.NewDebugLabeler(g.ExternalG(), "Attachments.Store", false),
		s3c:          &s3.Mem{},
		stash:        NewFileStash(os.TempDir()),
		keyTester:    kt,
		testing:      true,
	}
}

func (a *S3Store) UploadAsset(ctx context.Context, task *UploadTask, encryptedOut io.Writer) (res chat1.Asset, err error) {
	defer a.Trace(ctx, &err, "UploadAsset")()
	// compute the task hash, which seeds the encryption nonce
	if task.hash() == nil {
		task.computeHash()
	} else {
		if !a.testing {
			return res, errors.New("task.taskHash not nil")
		}
		a.Debug(ctx, "UploadAsset: skipping task hash calculation due to existing taskHash (testing only feature)")
	}

	// encrypt the stream
	enc := NewSignEncrypter()
	encLen := enc.EncryptedLen(task.FileSize)

	// check for a previous interrupted upload attempt
	var previous *AttachmentInfo
	resumable := encLen > minMultiSize // can only resume multi uploads
	if resumable {
		previous = a.previousUpload(ctx, task)
	}

	res, err = a.uploadAsset(ctx, task, enc, previous, resumable, encryptedOut)

	// if the upload is aborted, reset the stream and start over to get new keys
	if err == ErrAbortOnPartMismatch && previous != nil {
		a.Debug(ctx, "UploadAsset: resume call aborted, resetting stream and starting from scratch")
		a.aborts++
		err := task.Plaintext.Reset()
		if err != nil {
			a.Debug(ctx, "UploadAsset: reset failed: %+v", err)
		}
		task.computeHash()
		return a.uploadAsset(ctx, task, enc, nil, resumable, encryptedOut)
	}

	return res, err
}
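
// A note on the retry above: ErrAbortOnPartMismatch means the stashed resume
// state no longer matches what S3 holds. UploadAsset responds by rewinding the
// plaintext and re-running uploadAsset with previous == nil, which forces a
// fresh EncryptWithNonce pass below instead of a resumed EncryptResume one.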

func (a *S3Store) uploadAsset(ctx context.Context, task *UploadTask, enc *SignEncrypter,
	previous *AttachmentInfo, resumable bool, encryptedOut io.Writer) (asset chat1.Asset, err error) {
	defer a.Trace(ctx, &err, "uploadAsset")()
	var encReader io.Reader
	var ptHash hash.Hash
	if previous != nil {
		a.Debug(ctx, "uploadAsset: found previous upload for %s in conv %s", task.Filename,
			task.ConversationID)
		encReader, err = enc.EncryptResume(task.Plaintext, task.Nonce(), previous.EncKey, previous.SignKey,
			previous.VerifyKey)
		if err != nil {
			return chat1.Asset{}, err
		}
	} else {
		ptHash = sha256.New()
		tee := io.TeeReader(task.Plaintext, ptHash)
		encReader, err = enc.EncryptWithNonce(tee, task.Nonce())
		if err != nil {
			return chat1.Asset{}, err
		}
		if resumable {
			a.startUpload(ctx, task, enc)
		}
	}

	if a.testing && a.keyTester != nil {
		a.Debug(ctx, "uploadAsset: Store.keyTester exists, reporting keys")
		a.keyTester(enc.EncryptKey(), enc.VerifyKey())
	}

	// compute ciphertext hash
	encHash := sha256.New()
	tee := io.TeeReader(io.TeeReader(encReader, encHash), encryptedOut)

	// post to s3
	length := enc.EncryptedLen(task.FileSize)
	record := rpc.NewNetworkInstrumenter(a.G().RemoteNetworkInstrumenterStorage, "ChatAttachmentUpload")
	defer func() { _ = record.RecordAndFinish(ctx, length) }()
	upRes, err := a.PutS3(ctx, tee, length, task, previous)
	if err != nil {
		if err == ErrAbortOnPartMismatch && previous != nil {
			// erase information about the previous upload attempt
			a.finishUpload(ctx, task)
		}
		ew, ok := err.(*ErrorWrapper)
		if ok {
			a.Debug(ctx, "uploadAsset: PutS3 error details: %s", ew.Details())
		}
		return chat1.Asset{}, err
	}
	a.Debug(ctx, "uploadAsset: chat attachment upload: %+v", upRes)

	asset = chat1.Asset{
		Filename:  filepath.Base(task.Filename),
		Region:    upRes.Region,
		Endpoint:  upRes.Endpoint,
		Bucket:    upRes.Bucket,
		Path:      upRes.Path,
		Size:      upRes.Size,
		Key:       enc.EncryptKey(),
		VerifyKey: enc.VerifyKey(),
		EncHash:   encHash.Sum(nil),
		Nonce:     task.Nonce()[:],
	}
	if ptHash != nil {
		// the plaintext hash is only available in the non-resume case
		asset.PtHash = ptHash.Sum(nil)
	}
	if resumable {
		a.finishUpload(ctx, task)
	}
	return asset, nil
}
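
// For reference, the reader chain assembled in uploadAsset above (data flows
// top to bottom; the ptHash tee is skipped on the resume path):
//
//	task.Plaintext
//	    ├─tee→ ptHash (sha256 of plaintext → asset.PtHash)
//	    ↓
//	signencrypt (keys from enc, nonce from the task hash)
//	    ├─tee→ encHash (sha256 of ciphertext → asset.EncHash)
//	    ├─tee→ encryptedOut (caller's copy of the ciphertext)
//	    ↓
//	PutS3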

func (a *S3Store) getAssetBucket(asset chat1.Asset, params chat1.S3Params, signer s3.Signer) s3.BucketInt {
	region := a.regionFromAsset(asset)
	return a.s3Conn(signer, region, params.AccessKey, params.Token).Bucket(asset.Bucket)
}

func (a *S3Store) GetAssetReader(ctx context.Context, params chat1.S3Params, asset chat1.Asset,
	signer s3.Signer) (io.ReadCloser, error) {
	b := a.getAssetBucket(asset, params, signer)
	return b.GetReader(ctx, asset.Path)
}

func (a *S3Store) DecryptAsset(ctx context.Context, w io.Writer, body io.Reader, asset chat1.Asset,
	progressReporter types.ProgressReporter) error {
	// compute ciphertext hash
	encHash := sha256.New()
	verify := io.TeeReader(body, encHash)

	// to keep track of download progress
	progWriter := progress.NewProgressWriter(progressReporter, asset.Size)
	tee := io.TeeReader(verify, progWriter)

	// decrypt body
	dec := NewSignDecrypter()
	var decBody io.Reader
	if asset.Nonce != nil {
		var nonce [signencrypt.NonceSize]byte
		copy(nonce[:], asset.Nonce)
		decBody = dec.DecryptWithNonce(tee, &nonce, asset.Key, asset.VerifyKey)
	} else {
		decBody = dec.Decrypt(tee, asset.Key, asset.VerifyKey)
	}

	ptHash := sha256.New()
	tee = io.TeeReader(decBody, ptHash)
	n, err := io.Copy(w, tee)
	if err != nil {
		return err
	}

	a.Debug(ctx, "DecryptAsset: downloaded and decrypted to %d plaintext bytes", n)
	progWriter.Finish()

	// validate the EncHash
	if !hmac.Equal(asset.EncHash, encHash.Sum(nil)) {
		return fmt.Errorf("invalid attachment content hash")
	}
	// validate pt hash if we have it
	if asset.PtHash != nil && !hmac.Equal(asset.PtHash, ptHash.Sum(nil)) {
		return fmt.Errorf("invalid attachment plaintext hash")
	}
	a.Debug(ctx, "DecryptAsset: attachment content hash is valid")
	return nil
}
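
// exampleFetchAndDecrypt is an illustrative sketch (not called anywhere in
// this package) of using GetAssetReader and DecryptAsset separately rather
// than DownloadAsset, for a caller that wants to own the ciphertext stream.
// DecryptAsset verifies asset.EncHash over the ciphertext and, when present,
// asset.PtHash over the plaintext while writing to dst.
func exampleFetchAndDecrypt(ctx context.Context, store Store, params chat1.S3Params,
	asset chat1.Asset, signer s3.Signer, dst io.Writer, progRep types.ProgressReporter) error {
	body, err := store.GetAssetReader(ctx, params, asset, signer)
	if err != nil {
		return err
	}
	defer body.Close()
	return store.DecryptAsset(ctx, dst, body, asset, progRep)
}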

// DownloadAsset gets an object from S3 as described in asset.
func (a *S3Store) DownloadAsset(ctx context.Context, params chat1.S3Params, asset chat1.Asset,
	w io.Writer, signer s3.Signer, progress types.ProgressReporter) error {
	if asset.Key == nil || asset.VerifyKey == nil || asset.EncHash == nil {
		return fmt.Errorf("unencrypted attachments not supported: asset: %#v", asset)
	}
	body, err := a.GetAssetReader(ctx, params, asset, signer)
	if err != nil {
		return err
	}
	defer body.Close()
	a.Debug(ctx, "DownloadAsset: downloading %s from s3", asset.Path)
	return a.DecryptAsset(ctx, w, body, asset, progress)
}

type s3Seeker struct {
	utils.DebugLabeler
	ctx    context.Context
	asset  chat1.Asset
	bucket s3.BucketInt
	offset int64
}

func newS3Seeker(ctx context.Context, g *globals.Context, asset chat1.Asset, bucket s3.BucketInt) *s3Seeker {
	return &s3Seeker{
		DebugLabeler: utils.NewDebugLabeler(g.ExternalG(), "s3Seeker", false),
		ctx:          ctx,
		asset:        asset,
		bucket:       bucket,
	}
}

func (s *s3Seeker) Read(b []byte) (n int, err error) {
	defer s.Trace(s.ctx, &err, "Read(%v,%v)", s.offset, len(b))()
	if s.offset >= s.asset.Size {
		return 0, io.EOF
	}
	rc, err := s.bucket.GetReaderWithRange(s.ctx, s.asset.Path, s.offset, s.offset+int64(len(b)))
	if err != nil {
		return 0, err
	}
	defer rc.Close()
	var buf bytes.Buffer
	if _, err := io.Copy(&buf, rc); err != nil {
		return 0, err
	}
	// report only the bytes actually copied and advance the offset, so that
	// sequential reads behave like a normal io.Reader: the final ranged GET
	// near the end of the object can return fewer than len(b) bytes
	n = copy(b, buf.Bytes())
	s.offset += int64(n)
	return n, nil
}

func (s *s3Seeker) Seek(offset int64, whence int) (res int64, err error) {
	defer s.Trace(s.ctx, &err, "Seek(%v,%v)", s.offset, whence)()
	switch whence {
	case io.SeekStart:
		s.offset = offset
	case io.SeekCurrent:
		s.offset += offset
	case io.SeekEnd:
		// per the io.Seeker contract the offset is relative to the end,
		// so callers pass a non-positive offset here
		s.offset = s.asset.Size + offset
	}
	return s.offset, nil
}
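
// Together, Read and Seek let callers treat the remote object like a local
// file: the typical access pattern during streaming is seek-then-read, with
// one ranged S3 GET per Read call. An illustrative sequence:
//
//	sk := newS3Seeker(ctx, g, asset, bucket)
//	_, _ = sk.Seek(1<<20, io.SeekStart) // jump 1MB into the ciphertext
//	buf := make([]byte, 64<<10)
//	n, err := sk.Read(buf) // one HTTP range request for these 64KB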

func (a *S3Store) getStreamerCache(asset chat1.Asset) *lru.Cache {
	a.scMutex.Lock()
	defer a.scMutex.Unlock()
	if a.streamCache != nil && a.streamCache.path == asset.Path {
		return a.streamCache.cache
	}
	c, _ := lru.New(20) // store 20MB in memory while streaming
	a.streamCache = &streamCache{
		path:  asset.Path,
		cache: c,
	}
	return c
}
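
// The cache is per-asset: requesting a different path discards the previous
// cache entirely. The entries themselves are populated by the decoding read
// seeker below (presumably keyed by chunk position), so scrubbing back and
// forth within one attachment avoids re-fetching and re-decrypting recently
// seen chunks.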

func (a *S3Store) StreamAsset(ctx context.Context, params chat1.S3Params, asset chat1.Asset,
	signer s3.Signer) (io.ReadSeeker, error) {
	if asset.Key == nil || asset.VerifyKey == nil || asset.EncHash == nil {
		return nil, fmt.Errorf("unencrypted attachments not supported: asset: %#v", asset)
	}
	b := a.getAssetBucket(asset, params, signer)
	ptsize := signencrypt.GetPlaintextSize(asset.Size)
	var xencKey [signencrypt.SecretboxKeySize]byte
	copy(xencKey[:], asset.Key)
	var xverifyKey [ed25519.PublicKeySize]byte
	copy(xverifyKey[:], asset.VerifyKey)
	var nonce [signencrypt.NonceSize]byte
	if asset.Nonce != nil {
		copy(nonce[:], asset.Nonce)
	}
	// Make a ReadSeeker, and pass along the cache if we hit for the given path. We may get
	// a bunch of these calls for a given playback session.
	source := newS3Seeker(ctx, a.G(), asset, b)
	return signencrypt.NewDecodingReadSeeker(ctx, a.G(), source, ptsize, &xencKey, &xverifyKey,
		kbcrypto.SignaturePrefixChatAttachment, &nonce, a.getStreamerCache(asset)), nil
}
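
// The returned io.ReadSeeker is exactly what HTTP range-request serving
// wants. An illustrative caller outside this package, inside an http.Handler
// (net/http and time imports assumed), could do:
//
//	rs, err := store.StreamAsset(ctx, params, asset, signer)
//	if err == nil {
//		http.ServeContent(w, req, asset.Filename, time.Time{}, rs)
//	}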

func (a *S3Store) startUpload(ctx context.Context, task *UploadTask, encrypter *SignEncrypter) {
	info := AttachmentInfo{
		ObjectKey: task.S3Params.ObjectKey,
		EncKey:    encrypter.encKey,
		SignKey:   encrypter.signKey,
		VerifyKey: encrypter.verifyKey,
	}
	if err := a.stash.Start(task.stashKey(), info); err != nil {
		a.Debug(ctx, "startUpload: StashStart error: %s", err)
	}
}

func (a *S3Store) finishUpload(ctx context.Context, task *UploadTask) {
	if err := a.stash.Finish(task.stashKey()); err != nil {
		a.Debug(ctx, "finishUpload: StashFinish error: %s", err)
	}
}

func (a *S3Store) previousUpload(ctx context.Context, task *UploadTask) *AttachmentInfo {
	info, found, err := a.stash.Lookup(task.stashKey())
	if err != nil {
		a.Debug(ctx, "previousUpload: StashLookup error: %s", err)
		return nil
	}
	if !found {
		return nil
	}
	return &info
}

func (a *S3Store) regionFromParams(params chat1.S3Params) s3.Region {
	return s3.Region{
		Name:             params.RegionName,
		S3Endpoint:       params.RegionEndpoint,
		S3BucketEndpoint: params.RegionBucketEndpoint,
	}
}

func (a *S3Store) regionFromAsset(asset chat1.Asset) s3.Region {
	return s3.Region{
		Name:       asset.Region,
		S3Endpoint: asset.Endpoint,
	}
}

func (a *S3Store) s3Conn(signer s3.Signer, region s3.Region, accessKey string, sessionToken string) s3.Connection {
	conn := a.s3c.New(a.G().ExternalG(), signer, region)
	conn.SetAccessKey(accessKey)
	conn.SetSessionToken(sessionToken)
	return conn
}

func (a *S3Store) DeleteAssets(ctx context.Context, params chat1.S3Params, signer s3.Signer, assets []chat1.Asset) error {
	epick := libkb.FirstErrorPicker{}
	for _, asset := range assets {
		if err := a.DeleteAsset(ctx, params, signer, asset); err != nil {
			a.Debug(ctx, "DeleteAssets: DeleteAsset error: %s", err)
			epick.Push(err)
		}
	}
	return epick.Error()
}

func (a *S3Store) DeleteAsset(ctx context.Context, params chat1.S3Params, signer s3.Signer, asset chat1.Asset) error {
	region := a.regionFromAsset(asset)
	b := a.s3Conn(signer, region, params.AccessKey, params.Token).Bucket(asset.Bucket)
	return b.Del(ctx, asset.Path)
}