github.com/keybase/client/go@v0.0.0-20241007131713-f10651d043c8/chat/attachments/store_test.go (about)

     1  package attachments
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/rand"
     6  	"crypto/sha256"
     7  	"encoding/hex"
     8  	"io"
     9  	"strings"
    10  	"testing"
    11  
    12  	"github.com/keybase/client/go/externalstest"
    13  	"github.com/keybase/client/go/libkb"
    14  
    15  	"github.com/keybase/client/go/chat/globals"
    16  	"github.com/keybase/client/go/chat/s3"
    17  	"github.com/keybase/client/go/chat/signencrypt"
    18  	"github.com/keybase/client/go/chat/storage"
    19  	"github.com/keybase/client/go/protocol/chat1"
    20  	"github.com/stretchr/testify/require"
    21  	"golang.org/x/net/context"
    22  )
    23  
    24  const MB int64 = 1024 * 1024
    25  
    26  func TestSignEncrypter(t *testing.T) {
    27  	e := NewSignEncrypter()
    28  	el := e.EncryptedLen(100)
    29  	if el != 180 {
    30  		t.Errorf("enc len: %d, expected 180", el)
    31  	}
    32  
    33  	el = e.EncryptedLen(50 * 1024 * 1024)
    34  	if el != 52432880 {
    35  		t.Errorf("enc len: %d, expected 52432880", el)
    36  	}
    37  
    38  	pt := "plain text"
    39  	er, err := e.Encrypt(strings.NewReader(pt))
    40  	if err != nil {
    41  		t.Fatal(err)
    42  	}
    43  	ct, err := io.ReadAll(er)
    44  	if err != nil {
    45  		t.Fatal(err)
    46  	}
    47  
    48  	if string(ct) == pt {
    49  		t.Fatal("Encrypt did not change plaintext")
    50  	}
    51  
    52  	d := NewSignDecrypter()
    53  	dr := d.Decrypt(bytes.NewReader(ct), e.EncryptKey(), e.VerifyKey())
    54  	ptOut, err := io.ReadAll(dr)
    55  	if err != nil {
    56  		t.Fatal(err)
    57  	}
    58  	if string(ptOut) != pt {
    59  		t.Errorf("decrypted ciphertext doesn't match plaintext: %q, expected %q", ptOut, pt)
    60  	}
    61  
    62  	// reuse e to do another Encrypt, make sure keys change:
    63  	firstEncKey := e.EncryptKey()
    64  	firstVerifyKey := e.VerifyKey()
    65  
    66  	er2, err := e.Encrypt(strings.NewReader(pt))
    67  	if err != nil {
    68  		t.Fatal(err)
    69  	}
    70  	ct2, err := io.ReadAll(er2)
    71  	if err != nil {
    72  		t.Fatal(err)
    73  	}
    74  
    75  	if string(ct2) == pt {
    76  		t.Fatal("Encrypt did not change plaintext")
    77  	}
    78  	if bytes.Equal(ct, ct2) {
    79  		t.Fatal("second Encrypt result same as first")
    80  	}
    81  	if bytes.Equal(firstEncKey, e.EncryptKey()) {
    82  		t.Fatal("first enc key reused")
    83  	}
    84  	if bytes.Equal(firstVerifyKey, e.VerifyKey()) {
    85  		t.Fatal("first verify key reused")
    86  	}
    87  
    88  	dr2 := d.Decrypt(bytes.NewReader(ct2), e.EncryptKey(), e.VerifyKey())
    89  	ptOut2, err := io.ReadAll(dr2)
    90  	if err != nil {
    91  		t.Fatal(err)
    92  	}
    93  	if string(ptOut2) != pt {
    94  		t.Errorf("decrypted ciphertext doesn't match plaintext: %q, expected %q", ptOut2, pt)
    95  	}
    96  }
    97  
    98  func testStoreMultis(t *testing.T, s *S3Store) []*s3.MemMulti {
    99  	m, ok := s.s3c.(*s3.Mem)
   100  	if !ok {
   101  		t.Fatalf("not s3.Mem: %T", s.s3c)
   102  	}
   103  	// get *MemConn directly
   104  	c := m.NewMemConn()
   105  	return c.AllMultis()
   106  }
   107  
   108  func assertNumMultis(t *testing.T, s *S3Store, n int) {
   109  	numMultis := len(testStoreMultis(t, s))
   110  	if numMultis != n {
   111  		t.Errorf("number of s3 multis: %d, expected %d", numMultis, n)
   112  	}
   113  }
   114  
   115  func getMulti(t *testing.T, s *S3Store, index int) *s3.MemMulti {
   116  	all := testStoreMultis(t, s)
   117  	return all[index]
   118  }
   119  
   120  func assertNumParts(t *testing.T, s *S3Store, index, n int) {
   121  	m := getMulti(t, s, index)
   122  	p, err := m.ListParts(context.Background())
   123  	if err != nil {
   124  		t.Fatal(err)
   125  	}
   126  	if len(p) != n {
   127  		t.Errorf("num parts in multi: %d, expected %d", len(p), n)
   128  	}
   129  }
   130  
   131  func assertNumPutParts(t *testing.T, s *S3Store, index, calls int) {
   132  	m := getMulti(t, s, index)
   133  	if m.NumPutParts() != calls {
   134  		t.Errorf("num PutPart calls: %d, expected %d", m.NumPutParts(), calls)
   135  	}
   136  }
   137  
   138  func randBytes(t *testing.T, n int) []byte {
   139  	buf := make([]byte, n)
   140  	if _, err := rand.Read(buf); err != nil {
   141  		t.Fatal(err)
   142  	}
   143  	return buf
   144  }
   145  
   146  func randString(t *testing.T, nbytes int) string {
   147  	return hex.EncodeToString(randBytes(t, nbytes))
   148  }
   149  
   150  type ptsigner struct{}
   151  
   152  func (p *ptsigner) Sign(payload []byte) ([]byte, error) {
   153  	s := sha256.Sum256(payload)
   154  	return s[:], nil
   155  }
   156  
   157  type bytesReadResetter struct {
   158  	data   []byte
   159  	r      io.Reader
   160  	resets int
   161  }
   162  
   163  func newBytesReadResetter(d []byte) *bytesReadResetter {
   164  	return &bytesReadResetter{
   165  		data: d,
   166  		r:    bytes.NewReader(d),
   167  	}
   168  }
   169  
   170  func (b *bytesReadResetter) Read(p []byte) (n int, err error) {
   171  	return b.r.Read(p)
   172  }
   173  
   174  func (b *bytesReadResetter) Reset() error {
   175  	b.resets++
   176  	b.r = bytes.NewReader(b.data)
   177  	return nil
   178  }
   179  
   180  func makeUploadTask(t *testing.T, size int64) (plaintext []byte, task *UploadTask) {
   181  	plaintext = randBytes(t, int(size))
   182  	outboxID, _ := storage.NewOutboxID()
   183  	task = &UploadTask{
   184  		S3Params: chat1.S3Params{
   185  			Bucket:    "upload-test",
   186  			ObjectKey: randString(t, 8),
   187  		},
   188  		Filename:       randString(t, 8),
   189  		FileSize:       size,
   190  		Plaintext:      newBytesReadResetter(plaintext),
   191  		S3Signer:       &ptsigner{},
   192  		ConversationID: randBytes(t, 16),
   193  		OutboxID:       outboxID,
   194  	}
   195  	return plaintext, task
   196  }
   197  
   198  func TestUploadAssetSmall(t *testing.T) {
   199  	tc := externalstest.SetupTest(t, "chat_store", 1)
   200  	defer tc.Cleanup()
   201  	g := globals.NewContext(tc.G, &globals.ChatContext{})
   202  
   203  	s := NewStoreTesting(g, nil)
   204  	ctx := context.Background()
   205  	plaintext, task := makeUploadTask(t, 1*MB)
   206  	a, err := s.UploadAsset(ctx, task, io.Discard)
   207  	if err != nil {
   208  		t.Fatal(err)
   209  	}
   210  
   211  	var buf bytes.Buffer
   212  	if err = s.DownloadAsset(ctx, task.S3Params, a, &buf, task.S3Signer, nil); err != nil {
   213  		t.Fatal(err)
   214  	}
   215  	if !bytes.Equal(plaintext, buf.Bytes()) {
   216  		t.Errorf("downloaded asset did not match uploaded asset")
   217  	}
   218  
   219  	// small uploads should not (cannot) use multi interface to s3:
   220  	assertNumMultis(t, s, 0)
   221  }
   222  
   223  func TestUploadAssetLarge(t *testing.T) {
   224  	tc := externalstest.SetupTest(t, "chat_store", 1)
   225  	defer tc.Cleanup()
   226  	g := globals.NewContext(tc.G, &globals.ChatContext{})
   227  
   228  	s := NewStoreTesting(g, nil)
   229  	ctx := context.Background()
   230  	plaintext, task := makeUploadTask(t, 12*MB)
   231  	a, err := s.UploadAsset(ctx, task, io.Discard)
   232  	if err != nil {
   233  		t.Fatal(err)
   234  	}
   235  
   236  	var buf bytes.Buffer
   237  	if err = s.DownloadAsset(ctx, task.S3Params, a, &buf, task.S3Signer, nil); err != nil {
   238  		t.Fatal(err)
   239  	}
   240  	if !bytes.Equal(plaintext, buf.Bytes()) {
   241  		t.Errorf("downloaded asset did not match uploaded asset")
   242  	}
   243  
   244  	// large uploads should use multi interface to s3:
   245  	assertNumMultis(t, s, 1)
   246  }
   247  
   248  // dumbBuffer wraps a bytes.Buffer so io.Copy doesn't use WriteTo and we get a better test
   249  type dumbBuffer struct {
   250  	buf bytes.Buffer
   251  }
   252  
   253  func (d *dumbBuffer) Write(b []byte) (n int, err error) {
   254  	return d.buf.Write(b)
   255  }
   256  
   257  func (d *dumbBuffer) Bytes() []byte {
   258  	return d.buf.Bytes()
   259  }
   260  
   261  func newDumbBuffer() *dumbBuffer {
   262  	return &dumbBuffer{}
   263  }
   264  
   265  func TestStreamAsset(t *testing.T) {
   266  	tc := externalstest.SetupTest(t, "chat_store", 1)
   267  	defer tc.Cleanup()
   268  	g := globals.NewContext(tc.G, &globals.ChatContext{})
   269  
   270  	s := NewStoreTesting(g, nil)
   271  	ctx := context.Background()
   272  
   273  	testCase := func(mb, kb int64) {
   274  		total := mb*MB + kb
   275  		t.Logf("total: %d mb: %d kb: %d", total, mb, kb)
   276  		plaintext, task := makeUploadTask(t, total)
   277  		a, err := s.UploadAsset(ctx, task, io.Discard)
   278  		require.NoError(t, err)
   279  
   280  		// basic
   281  		var buf bytes.Buffer
   282  		t.Logf("basic")
   283  		s.streamCache = nil
   284  		rs, err := s.StreamAsset(ctx, task.S3Params, a, task.S3Signer)
   285  		require.NoError(t, err)
   286  		_, err = io.Copy(&buf, rs)
   287  		require.NoError(t, err)
   288  		require.True(t, bytes.Equal(plaintext, buf.Bytes()))
   289  		// use the cache
   290  		buf.Reset()
   291  		rs, err = s.StreamAsset(ctx, task.S3Params, a, task.S3Signer)
   292  		require.NoError(t, err)
   293  		_, err = io.Copy(&buf, rs)
   294  		require.NoError(t, err)
   295  		require.True(t, bytes.Equal(plaintext, buf.Bytes()))
   296  
   297  		// seek to half and copy
   298  		t.Logf("half")
   299  		dbuf := newDumbBuffer()
   300  		s.streamCache = nil
   301  		rs, err = s.StreamAsset(ctx, task.S3Params, a, task.S3Signer)
   302  		require.NoError(t, err)
   303  		_, err = rs.Seek(total/2, io.SeekStart)
   304  		require.NoError(t, err)
   305  		_, err = io.Copy(dbuf, rs)
   306  		require.NoError(t, err)
   307  		require.True(t, bytes.Equal(plaintext[total/2:], dbuf.Bytes()))
   308  
   309  		// use a fixed size buffer (like video playback)
   310  		t.Logf("buffer")
   311  		dbuf = newDumbBuffer()
   312  		s.streamCache = nil
   313  		scratch := make([]byte, 64*1024)
   314  		rs, err = s.StreamAsset(ctx, task.S3Params, a, task.S3Signer)
   315  		require.NoError(t, err)
   316  		_, err = io.CopyBuffer(dbuf, rs, scratch)
   317  		require.NoError(t, err)
   318  		require.True(t, bytes.Equal(plaintext, dbuf.Bytes()))
   319  	}
   320  
   321  	testCase(2, 0)
   322  	testCase(2, 400)
   323  	testCase(12, 0)
   324  	testCase(12, 543)
   325  }
   326  
   327  type uploader struct {
   328  	t             *testing.T
   329  	s             *S3Store
   330  	encKey        []byte
   331  	sigKey        []byte
   332  	plaintext     []byte
   333  	task          *UploadTask
   334  	breader       *bytesReadResetter
   335  	partialEncKey []byte // keys from UploadPartial
   336  	partialSigKey []byte
   337  	fullEncKey    []byte // keys from UploadResume
   338  	fullSigKey    []byte
   339  }
   340  
   341  func newUploader(t *testing.T, size int64, gc *libkb.GlobalContext) *uploader {
   342  	u := &uploader{t: t}
   343  	g := globals.NewContext(gc, &globals.ChatContext{})
   344  	u.s = NewStoreTesting(g, u.keyTracker)
   345  	u.plaintext, u.task = makeUploadTask(t, size)
   346  	return u
   347  }
   348  
   349  func (u *uploader) keyTracker(e, s []byte) {
   350  	u.encKey = e
   351  	u.sigKey = s
   352  }
   353  
   354  func (u *uploader) UploadResume() chat1.Asset {
   355  	u.s.blockLimit = 0
   356  	a, err := u.s.UploadAsset(context.Background(), u.task, io.Discard)
   357  	if err != nil {
   358  		u.t.Fatalf("expected second UploadAsset call to work, got: %s", err)
   359  	}
   360  	if a.Size != signencrypt.GetSealedSize(int64(len(u.plaintext))) {
   361  		u.t.Errorf("uploaded asset size: %d, expected %d", a.Size,
   362  			signencrypt.GetSealedSize(int64(len(u.plaintext))))
   363  	}
   364  	u.fullEncKey = u.encKey
   365  	u.fullSigKey = u.sigKey
   366  
   367  	// a resumed upload should reuse existing multi, so there should only be one:
   368  	assertNumMultis(u.t, u.s, 1)
   369  
   370  	// after resumed upload, all parts should have been uploaded
   371  	numParts := (int64(len(u.plaintext)) / (5 * MB)) + 1
   372  	assertNumParts(u.t, u.s, 0, int(numParts))
   373  
   374  	return a
   375  }
   376  
   377  func (u *uploader) UploadPartial(blocks int) {
   378  	u.s.blockLimit = blocks
   379  
   380  	_, err := u.s.UploadAsset(context.Background(), u.task, io.Discard)
   381  	if err == nil {
   382  		u.t.Fatal("expected incomplete upload to have error")
   383  	}
   384  
   385  	assertNumParts(u.t, u.s, 0, blocks)
   386  	assertNumPutParts(u.t, u.s, 0, blocks)
   387  
   388  	u.partialEncKey = u.encKey
   389  	u.partialSigKey = u.sigKey
   390  }
   391  
   392  func (u *uploader) ResetReader() {
   393  	u.s.blockLimit = 0
   394  	u.breader = newBytesReadResetter(u.plaintext)
   395  	u.task.Plaintext = u.breader
   396  }
   397  
   398  func (u *uploader) ResetHash() {
   399  	u.task.taskHash = nil
   400  }
   401  
   402  func (u *uploader) DownloadAndMatch(a chat1.Asset) {
   403  	var buf bytes.Buffer
   404  	if err := u.s.DownloadAsset(context.Background(), u.task.S3Params, a, &buf, u.task.S3Signer, nil); err != nil {
   405  		u.t.Fatal(err)
   406  	}
   407  	plaintextDownload := buf.Bytes()
   408  	if len(plaintextDownload) != len(u.plaintext) {
   409  		u.t.Errorf("downloaded asset len: %d, expected %d", len(plaintextDownload), len(u.plaintext))
   410  	}
   411  	if !bytes.Equal(u.plaintext, plaintextDownload) {
   412  		u.t.Errorf("downloaded asset did not match uploaded asset (%x v. %x)", plaintextDownload[:10], u.plaintext[:10])
   413  	}
   414  }
   415  
   416  func (u *uploader) AssertKeysChanged() {
   417  	if bytes.Equal(u.partialEncKey, u.fullEncKey) {
   418  		u.t.Errorf("partial enc key and full enc key match: enc key reused")
   419  	}
   420  	if bytes.Equal(u.partialSigKey, u.fullSigKey) {
   421  		u.t.Errorf("partial sig key and full sig key match: sig key reused")
   422  	}
   423  }
   424  
   425  func (u *uploader) AssertKeysReused() {
   426  	if !bytes.Equal(u.partialEncKey, u.fullEncKey) {
   427  		u.t.Errorf("partial enc key and full enc key different: enc key not reused")
   428  	}
   429  	if !bytes.Equal(u.partialSigKey, u.fullSigKey) {
   430  		u.t.Errorf("partial sig key and full sig key different: sig key not reused")
   431  	}
   432  }
   433  
   434  func (u *uploader) AssertNumPutParts(n int) {
   435  	assertNumPutParts(u.t, u.s, 0, n)
   436  }
   437  
   438  func (u *uploader) AssertNumResets(n int) {
   439  	if u.breader.resets != n {
   440  		u.t.Errorf("stream resets: %d, expected %d", u.breader.resets, n)
   441  	}
   442  }
   443  
   444  func (u *uploader) AssertNumAborts(n int) {
   445  	if u.s.aborts != n {
   446  		u.t.Errorf("aborts: %d, expected %d", u.s.aborts, n)
   447  	}
   448  }
   449  
   450  // Test uploading part of an asset, then resuming at a later point in time.
   451  // The asset does not change between the attempts.
   452  func TestUploadAssetResumeOK(t *testing.T) {
   453  	tc := externalstest.SetupTest(t, "chat_store", 1)
   454  	defer tc.Cleanup()
   455  
   456  	u := newUploader(t, 12*MB, tc.G)
   457  
   458  	// upload 2 parts of the asset
   459  	u.UploadPartial(2)
   460  
   461  	// resume the upload
   462  	u.ResetReader()
   463  	u.ResetHash()
   464  	a := u.UploadResume()
   465  
   466  	// download the asset
   467  	u.DownloadAndMatch(a)
   468  
   469  	// there should only be 3 calls to PutPart (2 in attempt 1, 1 in attempt 2).
   470  	u.AssertNumPutParts(3)
   471  
   472  	// keys should be reused
   473  	u.AssertKeysReused()
   474  
   475  	// no resets happen here
   476  	u.AssertNumResets(0)
   477  
   478  	// there should have been no aborts
   479  	u.AssertNumAborts(0)
   480  }
   481  
   482  // Test uploading part of an asset, then resuming at a later point in time.
   483  // The asset changes between the attempts.
   484  func TestUploadAssetResumeChange(t *testing.T) {
   485  	tc := externalstest.SetupTest(t, "chat_store", 1)
   486  	defer tc.Cleanup()
   487  
   488  	size := 12 * MB
   489  	u := newUploader(t, size, tc.G)
   490  
   491  	// upload 2 parts of the asset
   492  	u.UploadPartial(2)
   493  
   494  	// try again, changing the file and the hash (but same destination on s3):
   495  	// this simulates the file changing between upload attempt 1 and this attempt.
   496  	u.plaintext = randBytes(t, int(size))
   497  	u.ResetReader()
   498  	u.ResetHash()
   499  	a := u.UploadResume()
   500  	u.DownloadAndMatch(a)
   501  
   502  	// there should be 5 total calls to PutPart (2 in attempt 1, 3 in attempt 2).
   503  	u.AssertNumPutParts(5)
   504  
   505  	// keys should not be reused
   506  	u.AssertKeysChanged()
   507  
   508  	// only reset of second attempt should be after plaintext hash
   509  	u.AssertNumResets(1)
   510  
   511  	// we get one abort since outboxID doesn't change with the file
   512  	u.AssertNumAborts(1)
   513  }
   514  
   515  // Test uploading part of an asset, then resuming at a later point in time.
   516  // The asset changes after the plaintext hash is calculated in the resume attempt.
   517  func TestUploadAssetResumeRestart(t *testing.T) {
   518  	tc := externalstest.SetupTest(t, "chat_store", 1)
   519  	defer tc.Cleanup()
   520  
   521  	u := newUploader(t, 12*MB, tc.G)
   522  
   523  	// upload 2 parts of the asset
   524  	u.UploadPartial(2)
   525  
   526  	// try again, changing only one byte of the file (and not touching the plaintext hash).
   527  	// this should result in full restart of upload with new keys
   528  	u.plaintext[0] ^= 0x10
   529  	u.ResetReader()
   530  	// not calling u.ResetHash() here to simulate a change after the plaintext hash is
   531  	// calculated
   532  	a := u.UploadResume()
   533  	u.DownloadAndMatch(a)
   534  
   535  	// there should be 5 total calls to PutPart (2 in attempt 1, 3 in attempt 2).
   536  	u.AssertNumPutParts(5)
   537  
   538  	// keys should not be reused
   539  	u.AssertKeysChanged()
   540  
   541  	// one reset on the abort
   542  	u.AssertNumResets(1)
   543  
   544  	// there should have been one abort
   545  	u.AssertNumAborts(1)
   546  }