// Copyright 2019 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// This file incorporates work covered by the following copyright and
// permission notice:
//
// Copyright 2016 Attic Labs, Inc. All rights reserved.
// Licensed under the Apache License, version 2.0:
// http://www.apache.org/licenses/LICENSE-2.0

package nbs

import (
	"context"
	crand "crypto/rand"
	"io"
	"math/rand"
	"sync"
	"testing"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/request"
	"github.com/aws/aws-sdk-go/service/s3"
	"github.com/aws/aws-sdk-go/service/s3/s3iface"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/dolthub/dolt/go/store/hash"
)

// randomChunks fills a buffer of sz cryptographically random bytes and slices
// it into pseudo-randomly sized pieces. Piece lengths are drawn from r as a
// normal distribution centered on 4096 bytes (stddev 1024), so a fixed seed
// produces a deterministic partitioning; the final piece absorbs whatever
// remains of the buffer. The returned slices alias the single backing buffer.
func randomChunks(t *testing.T, r *rand.Rand, sz int) [][]byte {
	buf := make([]byte, sz)
	_, err := io.ReadFull(crand.Reader, buf)
	require.NoError(t, err)

	var ret [][]byte
	var i int
	for i < len(buf) {
		// Next chunk length ~ N(4096, 1024). With any sane seed this is
		// positive; a non-positive draw would stall the loop, but the tests
		// below only use fixed seeds for which that does not occur.
		j := int(r.NormFloat64()*1024 + 4096)
		if i+j >= len(buf) {
			// Not enough data left for a full draw: emit the tail as-is.
			ret = append(ret, buf[i:])
		} else {
			ret = append(ret, buf[i:i+j])
		}
		i += j
	}

	return ret
}

// TestRandomChunks pins the chunk counts produced by randomChunks for a fixed
// seed, guarding the determinism the persister tests below rely on.
func TestRandomChunks(t *testing.T) {
	r := rand.New(rand.NewSource(1024))
	res := randomChunks(t, r, 10)
	assert.Len(t, res, 1)
	res = randomChunks(t, r, 4096+2048)
	assert.Len(t, res, 2)
	res = randomChunks(t, r, 4096+4096)
	assert.Len(t, res, 3)
}

// TestAWSTablePersisterPersist exercises awsTablePersister.Persist against the
// in-memory fake S3, both with and without a key namespace. It covers
// multipart upload (part target smaller than the table), single-part upload
// (part target larger than the table), the no-new-chunks short circuit (no
// object should be written), and upload failure propagation.
func TestAWSTablePersisterPersist(t *testing.T) {
	ctx := context.Background()

	r := rand.New(rand.NewSource(1024))
	const sz15mb = 1 << 20 * 15
	mt := newMemTable(sz15mb)
	// ~12MB of chunk data: bigger than one 5MB part, smaller than 64MB.
	testChunks := randomChunks(t, r, 1<<20*12)
	for _, c := range testChunks {
		assert.Equal(t, mt.addChunk(computeAddr(c), c), chunkAdded)
	}

	var limits5mb = awsLimits{partTarget: 1 << 20 * 5}
	var limits64mb = awsLimits{partTarget: 1 << 20 * 64}

	t.Run("PersistToS3", func(t *testing.T) {
		// testIt runs the whole persist suite under the given namespace so the
		// same assertions cover namespaced and un-namespaced object keys.
		testIt := func(t *testing.T, ns string) {
			t.Run("InMultipleParts", func(t *testing.T) {
				assert := assert.New(t)
				s3svc := makeFakeS3(t)
				s3p := awsTablePersister{s3: s3svc, bucket: "bucket", limits: limits5mb, ns: ns, q: &UnlimitedQuotaProvider{}}

				src, err := s3p.Persist(context.Background(), mt, nil, &Stats{})
				require.NoError(t, err)
				defer src.close()

				if assert.True(mustUint32(src.count()) > 0) {
					// Read the table back out of the fake and verify every
					// chunk round-tripped.
					if r, err := s3svc.readerForTableWithNamespace(ctx, ns, src.hash()); assert.NotNil(r) && assert.NoError(err) {
						assertChunksInReader(testChunks, r, assert)
						r.close()
					}
				}
			})

			t.Run("InSinglePart", func(t *testing.T) {
				assert := assert.New(t)

				s3svc := makeFakeS3(t)
				s3p := awsTablePersister{s3: s3svc, bucket: "bucket", limits: limits64mb, ns: ns, q: &UnlimitedQuotaProvider{}}

				src, err := s3p.Persist(context.Background(), mt, nil, &Stats{})
				require.NoError(t, err)
				defer src.close()
				if assert.True(mustUint32(src.count()) > 0) {
					if r, err := s3svc.readerForTableWithNamespace(ctx, ns, src.hash()); assert.NotNil(r) && assert.NoError(err) {
						assertChunksInReader(testChunks, r, assert)
						r.close()
					}
				}
			})

			t.Run("NoNewChunks", func(t *testing.T) {
				assert := assert.New(t)

				mt := newMemTable(sz15mb)
				existingTable := newMemTable(sz15mb)

				// Every chunk in mt is also in existingTable, so Persist has
				// nothing novel to upload.
				for _, c := range testChunks {
					assert.Equal(mt.addChunk(computeAddr(c), c), chunkAdded)
					assert.Equal(existingTable.addChunk(computeAddr(c), c), chunkAdded)
				}

				s3svc := makeFakeS3(t)
				s3p := awsTablePersister{s3: s3svc, bucket: "bucket", limits: limits5mb, ns: ns, q: &UnlimitedQuotaProvider{}}

				src, err := s3p.Persist(context.Background(), mt, existingTable, &Stats{})
				require.NoError(t, err)
				defer src.close()
				assert.True(mustUint32(src.count()) == 0)

				// An empty persist must not have written an object to S3.
				_, present := s3svc.data[src.hash().String()]
				assert.False(present)
			})

			t.Run("Abort", func(t *testing.T) {
				assert := assert.New(t)

				// Allow exactly one successful UploadPart, then fail: a 12MB
				// table at 5MB parts needs several parts, so Persist must
				// surface the upload error.
				s3svc := &failingFakeS3{makeFakeS3(t), sync.Mutex{}, 1}
				s3p := awsTablePersister{s3: s3svc, bucket: "bucket", limits: limits5mb, ns: ns, q: &UnlimitedQuotaProvider{}}

				_, err := s3p.Persist(context.Background(), mt, nil, &Stats{})
				assert.Error(err)
			})
		}
		t.Run("WithoutNamespace", func(t *testing.T) {
			testIt(t, "")
		})
		t.Run("WithNamespace", func(t *testing.T) {
			testIt(t, "a-namespace-here")
		})
	})
}

// waitOnStoreTableCache is a test double for a table cache whose store() calls
// can be awaited via storeWG. Checkout returns any previously stored reader;
// checkin is a no-op. NOTE(review): nothing in this file references this type —
// presumably kept for other tests in the package; verify before removing.
type waitOnStoreTableCache struct {
	readers map[hash.Hash]io.ReaderAt
	mu      sync.RWMutex
	storeWG sync.WaitGroup
}

// checkout returns the cached reader for h (nil if absent); never errors.
func (mtc *waitOnStoreTableCache) checkout(h hash.Hash) (io.ReaderAt, error) {
	mtc.mu.RLock()
	defer mtc.mu.RUnlock()
	return mtc.readers[h], nil
}

// checkin is a no-op; the cache retains readers indefinitely.
func (mtc *waitOnStoreTableCache) checkin(h hash.Hash) error {
	return nil
}

// store records data under h and signals storeWG. The caller must have added
// to storeWG beforehand, and data must also implement io.ReaderAt (the type
// assertion panics otherwise — acceptable in a test double).
func (mtc *waitOnStoreTableCache) store(h hash.Hash, data io.Reader, size uint64) error {
	defer mtc.storeWG.Done()
	mtc.mu.Lock()
	defer mtc.mu.Unlock()
	mtc.readers[h] = data.(io.ReaderAt)
	return nil
}

// failingFakeS3 wraps fakeS3 to make UploadPartWithContext fail after a fixed
// number of successes, for exercising multipart-upload abort paths.
type failingFakeS3 struct {
	*fakeS3
	mu           sync.Mutex
	numSuccesses int
}

// UploadPartWithContext delegates to the underlying fakeS3 while numSuccesses
// remains positive, then returns a mock "MalformedXML" AWS error for every
// subsequent call. The mutex makes the countdown safe under the persister's
// concurrent part uploads.
func (m *failingFakeS3) UploadPartWithContext(ctx aws.Context, input *s3.UploadPartInput, opts ...request.Option) (*s3.UploadPartOutput, error) {
	m.mu.Lock()
	defer m.mu.Unlock()
	if m.numSuccesses > 0 {
		m.numSuccesses--
		return m.fakeS3.UploadPartWithContext(ctx, input)
	}
	return nil, mockAWSError("MalformedXML")
}

// TestAWSTablePersisterDividePlan checks that dividePlan splits a conjoin plan
// into S3 UploadPartCopy ranges ("copies") for sources whose chunk data fits
// within [minPartSize, maxPartSize], and manual buffered ranges ("manuals")
// for sources too small to be a legal copy part.
func TestAWSTablePersisterDividePlan(t *testing.T) {
	assert := assert.New(t)
	minPartSize, maxPartSize := uint64(16), uint64(32)
	// One source below min (must be manual), one within range, and one whose
	// chunk data exceeds max (must be split across multiple copy parts).
	tooSmall := bytesToChunkSource(t, []byte("a"))
	justRight := bytesToChunkSource(t, []byte("123456789"), []byte("abcdefghi"))
	bigUns := [][]byte{make([]byte, maxPartSize-1), make([]byte, maxPartSize-1)}
	for _, b := range bigUns {
		rand.Read(b)
	}
	tooBig := bytesToChunkSource(t, bigUns...)

	sources := chunkSources{justRight, tooBig, tooSmall}
	defer func() {
		for _, s := range sources {
			s.close()
		}
	}()
	plan, err := planRangeCopyConjoin(sources, &Stats{})
	require.NoError(t, err)
	copies, manuals, _, err := dividePlan(context.Background(), plan, minPartSize, maxPartSize)
	require.NoError(t, err)

	// Every copy part must be a legal S3 part size; sum the copied bytes per
	// source table so totals can be compared against each table's chunk data.
	perTableDataSize := map[string]int64{}
	for _, c := range copies {
		assert.True(minPartSize <= uint64(c.srcLen))
		assert.True(uint64(c.srcLen) <= maxPartSize)
		totalSize := perTableDataSize[c.name]
		totalSize += c.srcLen
		perTableDataSize[c.name] = totalSize
	}
	// Only justRight and tooBig are copyable; tooSmall must not appear.
	assert.Len(perTableDataSize, 2)
	assert.Contains(perTableDataSize, justRight.hash().String())
	assert.Contains(perTableDataSize, tooBig.hash().String())
	ti, err := justRight.index()
	require.NoError(t, err)
	assert.EqualValues(calcChunkRangeSize(ti), perTableDataSize[justRight.hash().String()])
	ti, err = tooBig.index()
	require.NoError(t, err)
	assert.EqualValues(calcChunkRangeSize(ti), perTableDataSize[tooBig.hash().String()])

	// The sub-minimum source becomes exactly one manual range covering all of
	// its chunk data.
	assert.Len(manuals, 1)
	ti, err = tooSmall.index()
	require.NoError(t, err)
	assert.EqualValues(calcChunkRangeSize(ti), manuals[0].end-manuals[0].start)
}

// TestAWSTablePersisterCalcPartSizes verifies that splitOnMaxSize partitions a
// data length into part sizes that each fall within [min, max] and that sum
// back to the original length, across several boundary-adjacent lengths.
func TestAWSTablePersisterCalcPartSizes(t *testing.T) {
	assert := assert.New(t)
	// max is just over 2*min so any length splits into in-range parts.
	min, max := uint64(8*1<<10), uint64(1+(16*1<<10))

	testPartSizes := func(dataLen uint64) {
		lengths := splitOnMaxSize(dataLen, max)
		var sum int64
		for _, l := range lengths {
			assert.True(uint64(l) >= min)
			assert.True(uint64(l) <= max)
			sum += l
		}
		// No bytes lost or duplicated by the split.
		assert.EqualValues(dataLen, sum)
	}

	testPartSizes(1 << 20)
	testPartSizes(max + 1)
	testPartSizes(10*max - 1)
	testPartSizes(max + max/2)
}

// TestAWSTablePersisterConjoinAll exercises awsTablePersister.ConjoinAll over
// mixes of source-table sizes relative to the S3 part-size limits: all tiny
// (buffered/manual path), all larger than maxPartSize (multi-part copy path),
// a mix of the two, and everything combined. Each case verifies that every
// input chunk is readable from the conjoined table in the fake S3.
func TestAWSTablePersisterConjoinAll(t *testing.T) {
	ctx := context.Background()
	const sz5mb = 1 << 20 * 5
	targetPartSize := uint64(sz5mb)
	minPartSize, maxPartSize := targetPartSize, 5*targetPartSize

	// Rate-limit channel shared by all persisters; capacity bounds concurrent
	// S3 operations.
	rl := make(chan struct{}, 8)
	defer close(rl)

	newPersister := func(s3svc s3iface.S3API) awsTablePersister {
		return awsTablePersister{
			s3svc,
			"bucket",
			rl,
			awsLimits{targetPartSize, minPartSize, maxPartSize},
			"",
			&UnlimitedQuotaProvider{},
		}
	}

	// Build enough small chunks (each minPartSize/5 bytes) that their total
	// table size just exceeds minPartSize — so using all of them crosses the
	// min-part threshold while dropping one stays under it.
	var smallChunks [][]byte
	rnd := rand.New(rand.NewSource(0))
	for smallChunkTotal := uint64(0); smallChunkTotal <= uint64(minPartSize); {
		small := make([]byte, minPartSize/5)
		rnd.Read(small)
		src := bytesToChunkSource(t, small)
		smallChunks = append(smallChunks, small)
		ti, err := src.index()
		require.NoError(t, err)
		smallChunkTotal += calcChunkRangeSize(ti)
		ti.Close()
	}

	t.Run("Small", func(t *testing.T) {
		// makeSources persists each chunk into its own single-chunk table.
		makeSources := func(s3p awsTablePersister, chunks [][]byte) (sources chunkSources) {
			for i := 0; i < len(chunks); i++ {
				mt := newMemTable(uint64(2 * targetPartSize))
				mt.addChunk(computeAddr(chunks[i]), chunks[i])
				cs, err := s3p.Persist(context.Background(), mt, nil, &Stats{})
				require.NoError(t, err)
				sources = append(sources, cs)
			}
			return
		}

		t.Run("TotalUnderMinSize", func(t *testing.T) {
			assert := assert.New(t)
			s3svc := makeFakeS3(t)
			s3p := newPersister(s3svc)

			// All but one small chunk: total stays below minPartSize.
			chunks := smallChunks[:len(smallChunks)-1]
			sources := makeSources(s3p, chunks)
			src, _, err := s3p.ConjoinAll(context.Background(), sources, &Stats{})
			require.NoError(t, err)
			defer src.close()
			for _, s := range sources {
				s.close()
			}

			if assert.True(mustUint32(src.count()) > 0) {
				if r, err := s3svc.readerForTable(ctx, src.hash()); assert.NotNil(r) && assert.NoError(err) {
					assertChunksInReader(chunks, r, assert)
					r.close()
				}
			}
		})

		t.Run("TotalOverMinSize", func(t *testing.T) {
			assert := assert.New(t)
			s3svc := makeFakeS3(t)
			s3p := newPersister(s3svc)

			// Every small chunk: total crosses minPartSize.
			sources := makeSources(s3p, smallChunks)
			src, _, err := s3p.ConjoinAll(context.Background(), sources, &Stats{})
			require.NoError(t, err)
			defer src.close()
			for _, s := range sources {
				s.close()
			}

			if assert.True(mustUint32(src.count()) > 0) {
				if r, err := s3svc.readerForTable(ctx, src.hash()); assert.NotNil(r) && assert.NoError(err) {
					assertChunksInReader(smallChunks, r, assert)
					r.close()
				}
			}
		})
	})

	// Two pairs of near-maxPartSize chunks: each pair's table holds more than
	// maxPartSize of chunk data, forcing multi-part copy during conjoin.
	bigUns1 := [][]byte{make([]byte, maxPartSize-1), make([]byte, maxPartSize-1)}
	bigUns2 := [][]byte{make([]byte, maxPartSize-1), make([]byte, maxPartSize-1)}
	for _, bu := range [][][]byte{bigUns1, bigUns2} {
		for _, b := range bu {
			rand.Read(b)
		}
	}

	t.Run("AllOverMax", func(t *testing.T) {
		assert := assert.New(t)
		s3svc := makeFakeS3(t)
		s3p := newPersister(s3svc)

		// Make 2 chunk sources that each have >maxPartSize chunk data
		sources := make(chunkSources, 2)
		for i, bu := range [][][]byte{bigUns1, bigUns2} {
			mt := newMemTable(uint64(2 * maxPartSize))
			for _, b := range bu {
				mt.addChunk(computeAddr(b), b)
			}

			var err error
			sources[i], err = s3p.Persist(context.Background(), mt, nil, &Stats{})
			require.NoError(t, err)
		}
		src, _, err := s3p.ConjoinAll(context.Background(), sources, &Stats{})
		require.NoError(t, err)
		defer src.close()
		for _, s := range sources {
			s.close()
		}

		if assert.True(mustUint32(src.count()) > 0) {
			if r, err := s3svc.readerForTable(ctx, src.hash()); assert.NotNil(r) && assert.NoError(err) {
				assertChunksInReader(bigUns1, r, assert)
				assertChunksInReader(bigUns2, r, assert)
				r.close()
			}
		}
	})

	t.Run("SomeOverMax", func(t *testing.T) {
		assert := assert.New(t)
		s3svc := makeFakeS3(t)
		s3p := newPersister(s3svc)

		// Add one chunk source that has >maxPartSize data
		mtb := newMemTable(uint64(2 * maxPartSize))
		for _, b := range bigUns1 {
			mtb.addChunk(computeAddr(b), b)
		}

		// Follow up with a chunk source where minPartSize < data size < maxPartSize
		medChunks := make([][]byte, 2)
		mt := newMemTable(uint64(2 * maxPartSize))
		for i := range medChunks {
			medChunks[i] = make([]byte, minPartSize+1)
			rand.Read(medChunks[i])
			mt.addChunk(computeAddr(medChunks[i]), medChunks[i])
		}
		cs1, err := s3p.Persist(context.Background(), mt, nil, &Stats{})
		require.NoError(t, err)
		cs2, err := s3p.Persist(context.Background(), mtb, nil, &Stats{})
		require.NoError(t, err)
		sources := chunkSources{cs1, cs2}

		src, _, err := s3p.ConjoinAll(context.Background(), sources, &Stats{})
		require.NoError(t, err)
		defer src.close()
		for _, s := range sources {
			s.close()
		}

		if assert.True(mustUint32(src.count()) > 0) {
			if r, err := s3svc.readerForTable(ctx, src.hash()); assert.NotNil(r) && assert.NoError(err) {
				assertChunksInReader(bigUns1, r, assert)
				assertChunksInReader(medChunks, r, assert)
				r.close()
			}
		}
	})

	t.Run("Mix", func(t *testing.T) {
		assert := assert.New(t)
		s3svc := makeFakeS3(t)
		s3p := newPersister(s3svc)

		// Start with small tables. Since total > minPartSize, will require more than one part to upload.
		sources := make(chunkSources, len(smallChunks))
		for i := 0; i < len(smallChunks); i++ {
			mt := newMemTable(uint64(2 * targetPartSize))
			mt.addChunk(computeAddr(smallChunks[i]), smallChunks[i])
			var err error
			sources[i], err = s3p.Persist(context.Background(), mt, nil, &Stats{})
			require.NoError(t, err)
		}

		// Now, add a table with big chunks that will require more than one upload copy part.
		mt := newMemTable(uint64(2 * maxPartSize))
		for _, b := range bigUns1 {
			mt.addChunk(computeAddr(b), b)
		}

		var err error
		cs, err := s3p.Persist(context.Background(), mt, nil, &Stats{})
		require.NoError(t, err)
		sources = append(sources, cs)

		// Last, some tables that should be directly upload-copyable
		medChunks := make([][]byte, 2)
		mt = newMemTable(uint64(2 * maxPartSize))
		for i := range medChunks {
			medChunks[i] = make([]byte, minPartSize+1)
			rand.Read(medChunks[i])
			mt.addChunk(computeAddr(medChunks[i]), medChunks[i])
		}

		cs, err = s3p.Persist(context.Background(), mt, nil, &Stats{})
		require.NoError(t, err)
		sources = append(sources, cs)

		src, _, err := s3p.ConjoinAll(context.Background(), sources, &Stats{})
		require.NoError(t, err)
		defer src.close()
		for _, s := range sources {
			s.close()
		}

		if assert.True(mustUint32(src.count()) > 0) {
			if r, err := s3svc.readerForTable(ctx, src.hash()); assert.NotNil(r) && assert.NoError(err) {
				assertChunksInReader(smallChunks, r, assert)
				assertChunksInReader(bigUns1, r, assert)
				assertChunksInReader(medChunks, r, assert)
				r.close()
			}
		}
	})
}

// bytesToChunkSource writes each bs element as a chunk into an in-memory NBS
// table, then wraps the finished table bytes in a chunkSource backed by an
// in-memory reader. Test helper; fails the test on any write/parse error.
func bytesToChunkSource(t *testing.T, bs ...[]byte) chunkSource {
	ctx := context.Background()
	sum := 0
	for _, b := range bs {
		sum += len(b)
	}
	// Size the buffer for the worst-case table of len(bs) chunks totaling sum bytes.
	maxSize := maxTableSize(uint64(len(bs)), uint64(sum))
	buff := make([]byte, maxSize)
	tw := newTableWriter(buff, nil)
	for _, b := range bs {
		tw.addChunk(computeAddr(b), b)
	}
	tableSize, name, err := tw.finish()
	require.NoError(t, err)
	data := buff[:tableSize]
	ti, err := parseTableIndexByCopy(ctx, data, &UnlimitedQuotaProvider{})
	require.NoError(t, err)
	rdr, err := newTableReader(ti, tableReaderAtFromBytes(data), fileBlockSize)
	require.NoError(t, err)
	return chunkSourceAdapter{rdr, name}
}