github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/store/datas/pull/pull_chunk_tracker_test.go (about)

     1  // Copyright 2024 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package pull
    16  
    17  import (
    18  	"context"
    19  	"errors"
    20  	"testing"
    21  
    22  	"github.com/dolthub/dolt/go/store/hash"
    23  
    24  	"github.com/stretchr/testify/assert"
    25  )
    26  
    27  func TestPullChunkTracker(t *testing.T) {
    28  	t.Run("Empty", func(t *testing.T) {
    29  		tracker := NewPullChunkTracker(context.Background(), make(hash.HashSet), TrackerConfig{
    30  			BatchSize: 64 * 1024,
    31  			HasManyer: nil,
    32  		})
    33  		hs, ok, err := tracker.GetChunksToFetch()
    34  		assert.Len(t, hs, 0)
    35  		assert.False(t, ok)
    36  		assert.NoError(t, err)
    37  		tracker.Close()
    38  	})
    39  
    40  	t.Run("HasAllInitial", func(t *testing.T) {
    41  		hs := make(hash.HashSet)
    42  		for i := byte(0); i < byte(10); i++ {
    43  			var h hash.Hash
    44  			h[0] = i
    45  			hs.Insert(h)
    46  		}
    47  		tracker := NewPullChunkTracker(context.Background(), hs, TrackerConfig{
    48  			BatchSize: 64 * 1024,
    49  			HasManyer: hasAllHaser{},
    50  		})
    51  		hs, ok, err := tracker.GetChunksToFetch()
    52  		assert.Len(t, hs, 0)
    53  		assert.False(t, ok)
    54  		assert.NoError(t, err)
    55  		tracker.Close()
    56  	})
    57  
    58  	t.Run("HasNoneInitial", func(t *testing.T) {
    59  		hs := make(hash.HashSet)
    60  		for i := byte(1); i <= byte(10); i++ {
    61  			var h hash.Hash
    62  			h[0] = i
    63  			hs.Insert(h)
    64  		}
    65  		tracker := NewPullChunkTracker(context.Background(), hs, TrackerConfig{
    66  			BatchSize: 64 * 1024,
    67  			HasManyer: hasNoneHaser{},
    68  		})
    69  		hs, ok, err := tracker.GetChunksToFetch()
    70  		assert.Len(t, hs, 10)
    71  		assert.True(t, ok)
    72  		assert.NoError(t, err)
    73  		for _ = range hs {
    74  			tracker.TickProcessed()
    75  		}
    76  		hs, ok, err = tracker.GetChunksToFetch()
    77  		assert.Len(t, hs, 0)
    78  		assert.False(t, ok)
    79  		assert.NoError(t, err)
    80  
    81  		for i := byte(1); i <= byte(10); i++ {
    82  			var h hash.Hash
    83  			h[1] = i
    84  			tracker.Seen(h)
    85  		}
    86  
    87  		cnt := 0
    88  		for {
    89  			hs, ok, err := tracker.GetChunksToFetch()
    90  			assert.NoError(t, err)
    91  			if !ok {
    92  				assert.Equal(t, 10, cnt)
    93  				break
    94  			}
    95  			cnt += len(hs)
    96  			for _ = range hs {
    97  				tracker.TickProcessed()
    98  			}
    99  		}
   100  
   101  		tracker.Close()
   102  	})
   103  
   104  	t.Run("HasManyError", func(t *testing.T) {
   105  		hs := make(hash.HashSet)
   106  		for i := byte(0); i < byte(10); i++ {
   107  			var h hash.Hash
   108  			h[0] = i
   109  			hs.Insert(h)
   110  		}
   111  		tracker := NewPullChunkTracker(context.Background(), hs, TrackerConfig{
   112  			BatchSize: 64 * 1024,
   113  			HasManyer: errHaser{},
   114  		})
   115  		_, _, err := tracker.GetChunksToFetch()
   116  		assert.Error(t, err)
   117  		tracker.Close()
   118  	})
   119  
   120  	t.Run("InitialAreSeen", func(t *testing.T) {
   121  		hs := make(hash.HashSet)
   122  		for i := byte(0); i < byte(10); i++ {
   123  			var h hash.Hash
   124  			h[0] = i
   125  			hs.Insert(h)
   126  		}
   127  		tracker := NewPullChunkTracker(context.Background(), hs, TrackerConfig{
   128  			BatchSize: 64 * 1024,
   129  			HasManyer: hasNoneHaser{},
   130  		})
   131  		hs, ok, err := tracker.GetChunksToFetch()
   132  		assert.Len(t, hs, 10)
   133  		assert.True(t, ok)
   134  		assert.NoError(t, err)
   135  
   136  		for i := byte(0); i < byte(10); i++ {
   137  			var h hash.Hash
   138  			h[0] = i
   139  			tracker.Seen(h)
   140  		}
   141  		for _ = range hs {
   142  			tracker.TickProcessed()
   143  		}
   144  
   145  		hs, ok, err = tracker.GetChunksToFetch()
   146  		assert.Len(t, hs, 0)
   147  		assert.False(t, ok)
   148  		assert.NoError(t, err)
   149  
   150  		tracker.Close()
   151  	})
   152  
   153  	t.Run("StaticHaser", func(t *testing.T) {
   154  		haser := staticHaser{make(hash.HashSet)}
   155  		initial := make([]hash.Hash, 4)
   156  		initial[0][0] = 1
   157  		initial[1][0] = 2
   158  		initial[2][0] = 1
   159  		initial[2][1] = 1
   160  		initial[3][0] = 1
   161  		initial[3][1] = 2
   162  		haser.has.Insert(initial[0])
   163  		haser.has.Insert(initial[1])
   164  		haser.has.Insert(initial[2])
   165  		haser.has.Insert(initial[3])
   166  
   167  		hs := make(hash.HashSet)
   168  		// Start with 1 - 5
   169  		for i := byte(1); i <= byte(5); i++ {
   170  			var h hash.Hash
   171  			h[0] = i
   172  			hs.Insert(h)
   173  		}
   174  		tracker := NewPullChunkTracker(context.Background(), hs, TrackerConfig{
   175  			BatchSize: 64 * 1024,
   176  			HasManyer: haser,
   177  		})
   178  
   179  		// Should get back 03, 04, 05
   180  		hs, ok, err := tracker.GetChunksToFetch()
   181  		assert.Len(t, hs, 3)
   182  		assert.True(t, ok)
   183  		assert.NoError(t, err)
   184  		for _ = range hs {
   185  			tracker.TickProcessed()
   186  		}
   187  
   188  		for i := byte(1); i <= byte(10); i++ {
   189  			var h hash.Hash
   190  			h[0] = 1
   191  			h[1] = i
   192  			tracker.Seen(h)
   193  		}
   194  
   195  		// Should get back 13, 14, 15, 16, 17, 18, 19, 1(10).
   196  		cnt := 0
   197  		for {
   198  			hs, ok, err := tracker.GetChunksToFetch()
   199  			assert.NoError(t, err)
   200  			if !ok {
   201  				break
   202  			}
   203  			cnt += len(hs)
   204  			for _ = range hs {
   205  				tracker.TickProcessed()
   206  			}
   207  		}
   208  		assert.Equal(t, 8, cnt)
   209  
   210  		tracker.Close()
   211  	})
   212  
   213  	t.Run("SmallBatches", func(t *testing.T) {
   214  		haser := staticHaser{make(hash.HashSet)}
   215  		initial := make([]hash.Hash, 4)
   216  		initial[0][0] = 1
   217  		initial[1][0] = 2
   218  		initial[2][0] = 1
   219  		initial[2][1] = 1
   220  		initial[3][0] = 1
   221  		initial[3][1] = 2
   222  		haser.has.Insert(initial[0])
   223  		haser.has.Insert(initial[1])
   224  		haser.has.Insert(initial[2])
   225  		haser.has.Insert(initial[3])
   226  
   227  		hs := make(hash.HashSet)
   228  		// Start with 1 - 5
   229  		for i := byte(1); i <= byte(5); i++ {
   230  			var h hash.Hash
   231  			h[0] = i
   232  			hs.Insert(h)
   233  		}
   234  		tracker := NewPullChunkTracker(context.Background(), hs, TrackerConfig{
   235  			BatchSize: 1,
   236  			HasManyer: haser,
   237  		})
   238  
   239  		// First call doesn't actually respect batch size.
   240  		hs, ok, err := tracker.GetChunksToFetch()
   241  		assert.Len(t, hs, 3)
   242  		assert.True(t, ok)
   243  		assert.NoError(t, err)
   244  		for _ = range hs {
   245  			tracker.TickProcessed()
   246  		}
   247  
   248  		for i := byte(1); i <= byte(10); i++ {
   249  			var h hash.Hash
   250  			h[0] = 1
   251  			h[1] = i
   252  			tracker.Seen(h)
   253  		}
   254  
   255  		// Should get back 13, 14, 15, 16, 17, 18, 19, 1(10); one at a time.
   256  		cnt := 0
   257  		for {
   258  			hs, ok, err := tracker.GetChunksToFetch()
   259  			assert.NoError(t, err)
   260  			if !ok {
   261  				break
   262  			}
   263  			assert.Len(t, hs, 1)
   264  			cnt += len(hs)
   265  			for _ = range hs {
   266  				tracker.TickProcessed()
   267  			}
   268  		}
   269  		assert.Equal(t, 8, cnt)
   270  
   271  		tracker.Close()
   272  	})
   273  }
   274  
   275  type hasAllHaser struct {
   276  }
   277  
   278  func (hasAllHaser) HasMany(context.Context, hash.HashSet) (hash.HashSet, error) {
   279  	return make(hash.HashSet), nil
   280  }
   281  
   282  type hasNoneHaser struct {
   283  }
   284  
   285  func (hasNoneHaser) HasMany(ctx context.Context, hs hash.HashSet) (hash.HashSet, error) {
   286  	return hs, nil
   287  }
   288  
   289  type staticHaser struct {
   290  	has hash.HashSet
   291  }
   292  
   293  func (s staticHaser) HasMany(ctx context.Context, query hash.HashSet) (hash.HashSet, error) {
   294  	ret := make(hash.HashSet)
   295  	for h := range query {
   296  		if !s.has.Has(h) {
   297  			ret.Insert(h)
   298  		}
   299  	}
   300  	return ret, nil
   301  }
   302  
   303  type errHaser struct {
   304  }
   305  
   306  func (errHaser) HasMany(ctx context.Context, hs hash.HashSet) (hash.HashSet, error) {
   307  	return nil, errors.New("always throws an error")
   308  }