github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/store/datas/pull/pull_chunk_fetcher_test.go (about)

     1  // Copyright 2024 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package pull
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"io"
    21  	"sync"
    22  	"testing"
    23  
    24  	"github.com/stretchr/testify/assert"
    25  
    26  	"github.com/dolthub/dolt/go/store/hash"
    27  	"github.com/dolthub/dolt/go/store/nbs"
    28  )
    29  
    30  func TestPullChunkFetcher(t *testing.T) {
    31  	t.Run("ImmediateCloseSend", func(t *testing.T) {
    32  		f := NewPullChunkFetcher(context.Background(), emptyGetManyer{})
    33  		assert.NoError(t, f.CloseSend())
    34  		_, err := f.Recv(context.Background())
    35  		assert.ErrorIs(t, err, io.EOF)
    36  		assert.NoError(t, f.Close())
    37  	})
    38  	t.Run("CanceledGetCtx", func(t *testing.T) {
    39  		ctx, c := context.WithCancel(context.Background())
    40  		gm := blockingGetManyer{make(chan struct{})}
    41  		f := NewPullChunkFetcher(context.Background(), gm)
    42  		hs := make(hash.HashSet)
    43  		var h hash.Hash
    44  		hs.Insert(h)
    45  		err := f.Get(ctx, hs)
    46  		assert.NoError(t, err)
    47  		c()
    48  		err = f.Get(ctx, hs)
    49  		assert.Error(t, err)
    50  		close(gm.block)
    51  		assert.NoError(t, f.Close())
    52  	})
    53  	t.Run("CanceledRecvCtx", func(t *testing.T) {
    54  		ctx, c := context.WithCancel(context.Background())
    55  		f := NewPullChunkFetcher(context.Background(), emptyGetManyer{})
    56  		c()
    57  		_, err := f.Recv(ctx)
    58  		assert.Error(t, err)
    59  		assert.NoError(t, f.Close())
    60  	})
    61  	t.Run("ReturnsDelieveredChunk", func(t *testing.T) {
    62  		var gm deliveringGetManyer
    63  		gm.C.FullCompressedChunk = make([]byte, 1024)
    64  		f := NewPullChunkFetcher(context.Background(), gm)
    65  		hs := make(hash.HashSet)
    66  		hs.Insert(gm.C.H)
    67  		var wg sync.WaitGroup
    68  		wg.Add(1)
    69  		go func() {
    70  			defer wg.Done()
    71  			cmp, err := f.Recv(context.Background())
    72  			assert.NoError(t, err)
    73  			assert.Equal(t, cmp.H, gm.C.H)
    74  			assert.Equal(t, cmp.FullCompressedChunk, gm.C.FullCompressedChunk)
    75  			_, err = f.Recv(context.Background())
    76  			assert.ErrorIs(t, err, io.EOF)
    77  			assert.NoError(t, f.Close())
    78  		}()
    79  		err := f.Get(context.Background(), hs)
    80  		assert.NoError(t, err)
    81  		assert.NoError(t, f.CloseSend())
    82  		wg.Wait()
    83  	})
    84  	t.Run("ReturnsEmptyCompressedChunk", func(t *testing.T) {
    85  		f := NewPullChunkFetcher(context.Background(), emptyGetManyer{})
    86  		hs := make(hash.HashSet)
    87  		var h hash.Hash
    88  		hs.Insert(h)
    89  		var wg sync.WaitGroup
    90  		wg.Add(1)
    91  		go func() {
    92  			defer wg.Done()
    93  			cmp, err := f.Recv(context.Background())
    94  			assert.NoError(t, err)
    95  			assert.Equal(t, cmp.H, h)
    96  			assert.Nil(t, cmp.FullCompressedChunk)
    97  			_, err = f.Recv(context.Background())
    98  			assert.ErrorIs(t, err, io.EOF)
    99  			assert.NoError(t, f.Close())
   100  		}()
   101  		err := f.Get(context.Background(), hs)
   102  		assert.NoError(t, err)
   103  		assert.NoError(t, f.CloseSend())
   104  		wg.Wait()
   105  	})
   106  	t.Run("ErrorGetManyer", func(t *testing.T) {
   107  		f := NewPullChunkFetcher(context.Background(), errorGetManyer{})
   108  		hs := make(hash.HashSet)
   109  		var h hash.Hash
   110  		hs.Insert(h)
   111  		var wg sync.WaitGroup
   112  		wg.Add(1)
   113  		go func() {
   114  			defer wg.Done()
   115  			_, err := f.Recv(context.Background())
   116  			assert.Error(t, err)
   117  			err = f.Close()
   118  			assert.Error(t, err)
   119  		}()
   120  		err := f.Get(context.Background(), hs)
   121  		assert.NoError(t, err)
   122  		err = f.Get(context.Background(), hs)
   123  		assert.Error(t, err)
   124  		wg.Wait()
   125  	})
   126  	t.Run("ClosedFetcherErrorsGet", func(t *testing.T) {
   127  		f := NewPullChunkFetcher(context.Background(), emptyGetManyer{})
   128  		assert.NoError(t, f.Close())
   129  		hs := make(hash.HashSet)
   130  		var h hash.Hash
   131  		hs.Insert(h)
   132  		assert.Error(t, f.Get(context.Background(), hs))
   133  	})
   134  }
   135  
   136  type emptyGetManyer struct {
   137  }
   138  
   139  func (emptyGetManyer) GetManyCompressed(ctx context.Context, hashes hash.HashSet, found func(context.Context, nbs.CompressedChunk)) error {
   140  	return nil
   141  }
   142  
   143  type deliveringGetManyer struct {
   144  	C nbs.CompressedChunk
   145  }
   146  
   147  func (d deliveringGetManyer) GetManyCompressed(ctx context.Context, hashes hash.HashSet, found func(context.Context, nbs.CompressedChunk)) error {
   148  	for _ = range hashes {
   149  		found(ctx, d.C)
   150  	}
   151  	return nil
   152  }
   153  
   154  type blockingGetManyer struct {
   155  	block chan struct{}
   156  }
   157  
   158  func (b blockingGetManyer) GetManyCompressed(ctx context.Context, hashes hash.HashSet, found func(context.Context, nbs.CompressedChunk)) error {
   159  	<-b.block
   160  	return nil
   161  }
   162  
   163  type errorGetManyer struct {
   164  }
   165  
   166  var getManyerErr = fmt.Errorf("always return an error")
   167  
   168  func (errorGetManyer) GetManyCompressed(ctx context.Context, hashes hash.HashSet, found func(context.Context, nbs.CompressedChunk)) error {
   169  	return getManyerErr
   170  }