github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/store/nbs/aws_table_persister_test.go (about)

     1  // Copyright 2019 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  //
    15  // This file incorporates work covered by the following copyright and
    16  // permission notice:
    17  //
    18  // Copyright 2016 Attic Labs, Inc. All rights reserved.
    19  // Licensed under the Apache License, version 2.0:
    20  // http://www.apache.org/licenses/LICENSE-2.0
    21  
    22  package nbs
    23  
    24  import (
    25  	"context"
    26  	crand "crypto/rand"
    27  	"io"
    28  	"math/rand"
    29  	"sync"
    30  	"testing"
    31  
    32  	"github.com/aws/aws-sdk-go/aws"
    33  	"github.com/aws/aws-sdk-go/aws/request"
    34  	"github.com/aws/aws-sdk-go/service/s3"
    35  	"github.com/aws/aws-sdk-go/service/s3/s3iface"
    36  	"github.com/stretchr/testify/assert"
    37  	"github.com/stretchr/testify/require"
    38  
    39  	"github.com/dolthub/dolt/go/store/hash"
    40  )
    41  
    42  func randomChunks(t *testing.T, r *rand.Rand, sz int) [][]byte {
    43  	buf := make([]byte, sz)
    44  	_, err := io.ReadFull(crand.Reader, buf)
    45  	require.NoError(t, err)
    46  
    47  	var ret [][]byte
    48  	var i int
    49  	for i < len(buf) {
    50  		j := int(r.NormFloat64()*1024 + 4096)
    51  		if i+j >= len(buf) {
    52  			ret = append(ret, buf[i:])
    53  		} else {
    54  			ret = append(ret, buf[i:i+j])
    55  		}
    56  		i += j
    57  	}
    58  
    59  	return ret
    60  }
    61  
    62  func TestRandomChunks(t *testing.T) {
    63  	r := rand.New(rand.NewSource(1024))
    64  	res := randomChunks(t, r, 10)
    65  	assert.Len(t, res, 1)
    66  	res = randomChunks(t, r, 4096+2048)
    67  	assert.Len(t, res, 2)
    68  	res = randomChunks(t, r, 4096+4096)
    69  	assert.Len(t, res, 3)
    70  }
    71  
// TestAWSTablePersisterPersist exercises awsTablePersister.Persist against a
// fake S3: uploading a ~12mb memtable as multiple parts (5mb part target) or
// a single part (64mb part target), with and without a key namespace,
// skipping the upload entirely when there are no novel chunks, and surfacing
// an error when a part upload fails mid-multipart-upload.
func TestAWSTablePersisterPersist(t *testing.T) {
	ctx := context.Background()

	r := rand.New(rand.NewSource(1024))
	const sz15mb = 1 << 20 * 15
	mt := newMemTable(sz15mb)
	testChunks := randomChunks(t, r, 1<<20*12)
	for _, c := range testChunks {
		assert.Equal(t, mt.addChunk(computeAddr(c), c), chunkAdded)
	}

	// 5mb part target forces a multipart upload of the ~12mb table; 64mb
	// lets the whole table go up in one part.
	var limits5mb = awsLimits{partTarget: 1 << 20 * 5}
	var limits64mb = awsLimits{partTarget: 1 << 20 * 64}

	t.Run("PersistToS3", func(t *testing.T) {
		// testIt runs the full subtest matrix for one namespace value.
		testIt := func(t *testing.T, ns string) {
			t.Run("InMultipleParts", func(t *testing.T) {
				assert := assert.New(t)
				s3svc := makeFakeS3(t)
				s3p := awsTablePersister{s3: s3svc, bucket: "bucket", limits: limits5mb, ns: ns, q: &UnlimitedQuotaProvider{}}

				src, err := s3p.Persist(context.Background(), mt, nil, &Stats{})
				require.NoError(t, err)
				defer src.close()

				// Every chunk written must be readable back from the fake S3
				// under the namespaced table key.
				if assert.True(mustUint32(src.count()) > 0) {
					if r, err := s3svc.readerForTableWithNamespace(ctx, ns, src.hash()); assert.NotNil(r) && assert.NoError(err) {
						assertChunksInReader(testChunks, r, assert)
						r.close()
					}
				}
			})

			t.Run("InSinglePart", func(t *testing.T) {
				assert := assert.New(t)

				s3svc := makeFakeS3(t)
				s3p := awsTablePersister{s3: s3svc, bucket: "bucket", limits: limits64mb, ns: ns, q: &UnlimitedQuotaProvider{}}

				src, err := s3p.Persist(context.Background(), mt, nil, &Stats{})
				require.NoError(t, err)
				defer src.close()
				if assert.True(mustUint32(src.count()) > 0) {
					if r, err := s3svc.readerForTableWithNamespace(ctx, ns, src.hash()); assert.NotNil(r) && assert.NoError(err) {
						assertChunksInReader(testChunks, r, assert)
						r.close()
					}
				}
			})

			t.Run("NoNewChunks", func(t *testing.T) {
				assert := assert.New(t)

				// Build a "haver" containing every chunk in mt, so Persist
				// finds nothing novel and must not upload anything.
				mt := newMemTable(sz15mb)
				existingTable := newMemTable(sz15mb)

				for _, c := range testChunks {
					assert.Equal(mt.addChunk(computeAddr(c), c), chunkAdded)
					assert.Equal(existingTable.addChunk(computeAddr(c), c), chunkAdded)
				}

				s3svc := makeFakeS3(t)
				s3p := awsTablePersister{s3: s3svc, bucket: "bucket", limits: limits5mb, ns: ns, q: &UnlimitedQuotaProvider{}}

				src, err := s3p.Persist(context.Background(), mt, existingTable, &Stats{})
				require.NoError(t, err)
				defer src.close()
				assert.True(mustUint32(src.count()) == 0)

				// Nothing should have landed in the fake S3 bucket.
				_, present := s3svc.data[src.hash().String()]
				assert.False(present)
			})

			t.Run("Abort", func(t *testing.T) {
				assert := assert.New(t)

				// Allow exactly one UploadPart to succeed, then fail; the
				// multipart upload of the 12mb table must therefore error.
				s3svc := &failingFakeS3{makeFakeS3(t), sync.Mutex{}, 1}
				s3p := awsTablePersister{s3: s3svc, bucket: "bucket", limits: limits5mb, ns: ns, q: &UnlimitedQuotaProvider{}}

				_, err := s3p.Persist(context.Background(), mt, nil, &Stats{})
				assert.Error(err)
			})
		}
		t.Run("WithoutNamespace", func(t *testing.T) {
			testIt(t, "")
		})
		t.Run("WithNamespace", func(t *testing.T) {
			testIt(t, "a-namespace-here")
		})
	})
}
   163  
// waitOnStoreTableCache is a test double for a table cache whose store()
// signals storeWG, so a test can block until table data has been cached.
// NOTE(review): no callers are visible in this file chunk — presumably used
// elsewhere in the package; confirm before removing.
type waitOnStoreTableCache struct {
	readers map[hash.Hash]io.ReaderAt // cached table data, keyed by table hash
	mu      sync.RWMutex              // guards readers
	storeWG sync.WaitGroup            // Done()'d once per store() call
}
   169  
   170  func (mtc *waitOnStoreTableCache) checkout(h hash.Hash) (io.ReaderAt, error) {
   171  	mtc.mu.RLock()
   172  	defer mtc.mu.RUnlock()
   173  	return mtc.readers[h], nil
   174  }
   175  
// checkin is a no-op; this test cache does not track outstanding checkouts.
func (mtc *waitOnStoreTableCache) checkin(h hash.Hash) error {
	return nil
}
   179  
// store caches data under h and signals storeWG so waiters can proceed.
// data must also implement io.ReaderAt or the type assertion panics; size
// is accepted to satisfy the interface but is unused here.
func (mtc *waitOnStoreTableCache) store(h hash.Hash, data io.Reader, size uint64) error {
	// Deferred first so it runs last (after the mutex is released), meaning
	// anyone woken by the WaitGroup can immediately take the lock.
	defer mtc.storeWG.Done()
	mtc.mu.Lock()
	defer mtc.mu.Unlock()
	mtc.readers[h] = data.(io.ReaderAt)
	return nil
}
   187  
// failingFakeS3 wraps fakeS3 so that a fixed number of UploadPart calls
// succeed before every subsequent one fails, letting tests exercise
// multipart-upload error handling.
type failingFakeS3 struct {
	*fakeS3
	mu           sync.Mutex // guards numSuccesses
	numSuccesses int        // remaining UploadPart calls that will succeed
}
   193  
   194  func (m *failingFakeS3) UploadPartWithContext(ctx aws.Context, input *s3.UploadPartInput, opts ...request.Option) (*s3.UploadPartOutput, error) {
   195  	m.mu.Lock()
   196  	defer m.mu.Unlock()
   197  	if m.numSuccesses > 0 {
   198  		m.numSuccesses--
   199  		return m.fakeS3.UploadPartWithContext(ctx, input)
   200  	}
   201  	return nil, mockAWSError("MalformedXML")
   202  }
   203  
// TestAWSTablePersisterDividePlan verifies that dividePlan splits a conjoin
// plan into S3 copy-part ranges (for source spans whose size lies within
// [minPartSize, maxPartSize]) and manual-upload ranges (for spans too small
// to be copied as multipart parts).
func TestAWSTablePersisterDividePlan(t *testing.T) {
	assert := assert.New(t)
	minPartSize, maxPartSize := uint64(16), uint64(32)
	// One source below minPartSize, one within bounds, and one whose chunk
	// data exceeds maxPartSize (so it must be split across multiple copies).
	tooSmall := bytesToChunkSource(t, []byte("a"))
	justRight := bytesToChunkSource(t, []byte("123456789"), []byte("abcdefghi"))
	bigUns := [][]byte{make([]byte, maxPartSize-1), make([]byte, maxPartSize-1)}
	for _, b := range bigUns {
		rand.Read(b)
	}
	tooBig := bytesToChunkSource(t, bigUns...)

	sources := chunkSources{justRight, tooBig, tooSmall}
	defer func() {
		for _, s := range sources {
			s.close()
		}
	}()
	plan, err := planRangeCopyConjoin(sources, &Stats{})
	require.NoError(t, err)
	copies, manuals, _, err := dividePlan(context.Background(), plan, minPartSize, maxPartSize)
	require.NoError(t, err)

	// Every copy range must respect the part-size bounds; tally copied bytes
	// per source table so totals can be checked against each table's index.
	perTableDataSize := map[string]int64{}
	for _, c := range copies {
		assert.True(minPartSize <= uint64(c.srcLen))
		assert.True(uint64(c.srcLen) <= maxPartSize)
		totalSize := perTableDataSize[c.name]
		totalSize += c.srcLen
		perTableDataSize[c.name] = totalSize
	}
	// Only justRight and tooBig are copy-able; tooSmall must not appear.
	assert.Len(perTableDataSize, 2)
	assert.Contains(perTableDataSize, justRight.hash().String())
	assert.Contains(perTableDataSize, tooBig.hash().String())
	// The copied byte totals must equal each table's chunk-range size.
	ti, err := justRight.index()
	require.NoError(t, err)
	assert.EqualValues(calcChunkRangeSize(ti), perTableDataSize[justRight.hash().String()])
	ti, err = tooBig.index()
	require.NoError(t, err)
	assert.EqualValues(calcChunkRangeSize(ti), perTableDataSize[tooBig.hash().String()])

	// tooSmall's data ends up as a single manual-upload range.
	assert.Len(manuals, 1)
	ti, err = tooSmall.index()
	require.NoError(t, err)
	assert.EqualValues(calcChunkRangeSize(ti), manuals[0].end-manuals[0].start)
}
   249  
   250  func TestAWSTablePersisterCalcPartSizes(t *testing.T) {
   251  	assert := assert.New(t)
   252  	min, max := uint64(8*1<<10), uint64(1+(16*1<<10))
   253  
   254  	testPartSizes := func(dataLen uint64) {
   255  		lengths := splitOnMaxSize(dataLen, max)
   256  		var sum int64
   257  		for _, l := range lengths {
   258  			assert.True(uint64(l) >= min)
   259  			assert.True(uint64(l) <= max)
   260  			sum += l
   261  		}
   262  		assert.EqualValues(dataLen, sum)
   263  	}
   264  
   265  	testPartSizes(1 << 20)
   266  	testPartSizes(max + 1)
   267  	testPartSizes(10*max - 1)
   268  	testPartSizes(max + max/2)
   269  }
   270  
// TestAWSTablePersisterConjoinAll exercises awsTablePersister.ConjoinAll
// across the size regimes that matter for S3 multipart copy: all-small
// sources (total under/over the minimum part size), sources whose data
// exceeds the maximum part size, a mix of medium and oversized sources, and
// finally a mix of all three. In every case the conjoined table must contain
// all of the original chunks.
func TestAWSTablePersisterConjoinAll(t *testing.T) {
	ctx := context.Background()
	const sz5mb = 1 << 20 * 5
	targetPartSize := uint64(sz5mb)
	minPartSize, maxPartSize := targetPartSize, 5*targetPartSize

	// Rate-limiting channel shared by every persister built below.
	rl := make(chan struct{}, 8)
	defer close(rl)

	newPersister := func(s3svc s3iface.S3API) awsTablePersister {
		return awsTablePersister{
			s3svc,
			"bucket",
			rl,
			awsLimits{targetPartSize, minPartSize, maxPartSize},
			"",
			&UnlimitedQuotaProvider{},
		}
	}

	// Accumulate small chunks until their combined on-table size just
	// exceeds minPartSize, so subtests can slice the set to sit on either
	// side of that boundary.
	// NOTE(review): src is never closed here — only its index is; confirm
	// that this chunkSource needs no explicit close on this path.
	var smallChunks [][]byte
	rnd := rand.New(rand.NewSource(0))
	for smallChunkTotal := uint64(0); smallChunkTotal <= uint64(minPartSize); {
		small := make([]byte, minPartSize/5)
		rnd.Read(small)
		src := bytesToChunkSource(t, small)
		smallChunks = append(smallChunks, small)
		ti, err := src.index()
		require.NoError(t, err)
		smallChunkTotal += calcChunkRangeSize(ti)
		ti.Close()
	}

	t.Run("Small", func(t *testing.T) {
		// makeSources persists each chunk into its own single-chunk table.
		makeSources := func(s3p awsTablePersister, chunks [][]byte) (sources chunkSources) {
			for i := 0; i < len(chunks); i++ {
				mt := newMemTable(uint64(2 * targetPartSize))
				mt.addChunk(computeAddr(chunks[i]), chunks[i])
				cs, err := s3p.Persist(context.Background(), mt, nil, &Stats{})
				require.NoError(t, err)
				sources = append(sources, cs)
			}
			return
		}

		t.Run("TotalUnderMinSize", func(t *testing.T) {
			assert := assert.New(t)
			s3svc := makeFakeS3(t)
			s3p := newPersister(s3svc)

			// Dropping the last chunk brings the total back under minPartSize.
			chunks := smallChunks[:len(smallChunks)-1]
			sources := makeSources(s3p, chunks)
			src, _, err := s3p.ConjoinAll(context.Background(), sources, &Stats{})
			require.NoError(t, err)
			defer src.close()
			for _, s := range sources {
				s.close()
			}

			if assert.True(mustUint32(src.count()) > 0) {
				if r, err := s3svc.readerForTable(ctx, src.hash()); assert.NotNil(r) && assert.NoError(err) {
					assertChunksInReader(chunks, r, assert)
					r.close()
				}
			}
		})

		t.Run("TotalOverMinSize", func(t *testing.T) {
			assert := assert.New(t)
			s3svc := makeFakeS3(t)
			s3p := newPersister(s3svc)

			sources := makeSources(s3p, smallChunks)
			src, _, err := s3p.ConjoinAll(context.Background(), sources, &Stats{})
			require.NoError(t, err)
			defer src.close()
			for _, s := range sources {
				s.close()
			}

			if assert.True(mustUint32(src.count()) > 0) {
				if r, err := s3svc.readerForTable(ctx, src.hash()); assert.NotNil(r) && assert.NoError(err) {
					assertChunksInReader(smallChunks, r, assert)
					r.close()
				}
			}
		})
	})

	// Two independent sets of near-maxPartSize chunks; each memtable built
	// from one set therefore holds more than maxPartSize of chunk data.
	bigUns1 := [][]byte{make([]byte, maxPartSize-1), make([]byte, maxPartSize-1)}
	bigUns2 := [][]byte{make([]byte, maxPartSize-1), make([]byte, maxPartSize-1)}
	for _, bu := range [][][]byte{bigUns1, bigUns2} {
		for _, b := range bu {
			rand.Read(b)
		}
	}

	t.Run("AllOverMax", func(t *testing.T) {
		assert := assert.New(t)
		s3svc := makeFakeS3(t)
		s3p := newPersister(s3svc)

		// Make 2 chunk sources that each have >maxPartSize chunk data
		sources := make(chunkSources, 2)
		for i, bu := range [][][]byte{bigUns1, bigUns2} {
			mt := newMemTable(uint64(2 * maxPartSize))
			for _, b := range bu {
				mt.addChunk(computeAddr(b), b)
			}

			var err error
			sources[i], err = s3p.Persist(context.Background(), mt, nil, &Stats{})
			require.NoError(t, err)
		}
		src, _, err := s3p.ConjoinAll(context.Background(), sources, &Stats{})
		require.NoError(t, err)
		defer src.close()
		for _, s := range sources {
			s.close()
		}

		if assert.True(mustUint32(src.count()) > 0) {
			if r, err := s3svc.readerForTable(ctx, src.hash()); assert.NotNil(r) && assert.NoError(err) {
				assertChunksInReader(bigUns1, r, assert)
				assertChunksInReader(bigUns2, r, assert)
				r.close()
			}
		}
	})

	t.Run("SomeOverMax", func(t *testing.T) {
		assert := assert.New(t)
		s3svc := makeFakeS3(t)
		s3p := newPersister(s3svc)

		// Add one chunk source that has >maxPartSize data
		mtb := newMemTable(uint64(2 * maxPartSize))
		for _, b := range bigUns1 {
			mtb.addChunk(computeAddr(b), b)
		}

		// Follow up with a chunk source where minPartSize < data size < maxPartSize
		medChunks := make([][]byte, 2)
		mt := newMemTable(uint64(2 * maxPartSize))
		for i := range medChunks {
			medChunks[i] = make([]byte, minPartSize+1)
			rand.Read(medChunks[i])
			mt.addChunk(computeAddr(medChunks[i]), medChunks[i])
		}
		cs1, err := s3p.Persist(context.Background(), mt, nil, &Stats{})
		require.NoError(t, err)
		cs2, err := s3p.Persist(context.Background(), mtb, nil, &Stats{})
		require.NoError(t, err)
		sources := chunkSources{cs1, cs2}

		src, _, err := s3p.ConjoinAll(context.Background(), sources, &Stats{})
		require.NoError(t, err)
		defer src.close()
		for _, s := range sources {
			s.close()
		}

		if assert.True(mustUint32(src.count()) > 0) {
			if r, err := s3svc.readerForTable(ctx, src.hash()); assert.NotNil(r) && assert.NoError(err) {
				assertChunksInReader(bigUns1, r, assert)
				assertChunksInReader(medChunks, r, assert)
				r.close()
			}
		}
	})

	t.Run("Mix", func(t *testing.T) {
		assert := assert.New(t)
		s3svc := makeFakeS3(t)
		s3p := newPersister(s3svc)

		// Start with small tables. Since total > minPartSize, will require more than one part to upload.
		sources := make(chunkSources, len(smallChunks))
		for i := 0; i < len(smallChunks); i++ {
			mt := newMemTable(uint64(2 * targetPartSize))
			mt.addChunk(computeAddr(smallChunks[i]), smallChunks[i])
			var err error
			sources[i], err = s3p.Persist(context.Background(), mt, nil, &Stats{})
			require.NoError(t, err)
		}

		// Now, add a table with big chunks that will require more than one upload copy part.
		mt := newMemTable(uint64(2 * maxPartSize))
		for _, b := range bigUns1 {
			mt.addChunk(computeAddr(b), b)
		}

		var err error
		cs, err := s3p.Persist(context.Background(), mt, nil, &Stats{})
		require.NoError(t, err)
		sources = append(sources, cs)

		// Last, some tables that should be directly upload-copyable
		medChunks := make([][]byte, 2)
		mt = newMemTable(uint64(2 * maxPartSize))
		for i := range medChunks {
			medChunks[i] = make([]byte, minPartSize+1)
			rand.Read(medChunks[i])
			mt.addChunk(computeAddr(medChunks[i]), medChunks[i])
		}

		cs, err = s3p.Persist(context.Background(), mt, nil, &Stats{})
		require.NoError(t, err)
		sources = append(sources, cs)

		src, _, err := s3p.ConjoinAll(context.Background(), sources, &Stats{})
		require.NoError(t, err)
		defer src.close()
		for _, s := range sources {
			s.close()
		}

		if assert.True(mustUint32(src.count()) > 0) {
			if r, err := s3svc.readerForTable(ctx, src.hash()); assert.NotNil(r) && assert.NoError(err) {
				assertChunksInReader(smallChunks, r, assert)
				assertChunksInReader(bigUns1, r, assert)
				assertChunksInReader(medChunks, r, assert)
				r.close()
			}
		}
	})
}
   498  
   499  func bytesToChunkSource(t *testing.T, bs ...[]byte) chunkSource {
   500  	ctx := context.Background()
   501  	sum := 0
   502  	for _, b := range bs {
   503  		sum += len(b)
   504  	}
   505  	maxSize := maxTableSize(uint64(len(bs)), uint64(sum))
   506  	buff := make([]byte, maxSize)
   507  	tw := newTableWriter(buff, nil)
   508  	for _, b := range bs {
   509  		tw.addChunk(computeAddr(b), b)
   510  	}
   511  	tableSize, name, err := tw.finish()
   512  	require.NoError(t, err)
   513  	data := buff[:tableSize]
   514  	ti, err := parseTableIndexByCopy(ctx, data, &UnlimitedQuotaProvider{})
   515  	require.NoError(t, err)
   516  	rdr, err := newTableReader(ti, tableReaderAtFromBytes(data), fileBlockSize)
   517  	require.NoError(t, err)
   518  	return chunkSourceAdapter{rdr, name}
   519  }