github.com/google/trillian-examples@v0.0.0-20240520080811-0d40d35cef0e/clone/internal/download/batch_test.go (about)

     1  // Copyright 2021 Google LLC. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package download
    16  
    17  import (
    18  	"context"
    19  	"errors"
    20  	"fmt"
    21  	"strconv"
    22  	"sync/atomic"
    23  	"testing"
    24  	"time"
    25  )
    26  
    27  func getFakeFetch(fetchOnly uint64) func(start uint64, leaves [][]byte) (uint64, error) {
    28  	return func(start uint64, leaves [][]byte) (uint64, error) {
    29  		for i := range leaves {
    30  			if uint64(i) == fetchOnly {
    31  				return uint64(i), nil
    32  			}
    33  			leaves[i] = []byte(strconv.Itoa(int(start) + i))
    34  		}
    35  		return uint64(len(leaves)), nil
    36  	}
    37  }
    38  
    39  type testCase struct {
    40  	name            string
    41  	first, treeSize uint64
    42  	batchSize       uint
    43  	workers         uint
    44  	wantErr         bool
    45  	fakeFetch       func(start uint64, leaves [][]byte) (uint64, error)
    46  }
    47  
    48  func TestFetchWorkerRun(t *testing.T) {
    49  	for _, test := range []testCase{
    50  		{
    51  			name:      "smallest batch",
    52  			first:     0,
    53  			treeSize:  10,
    54  			batchSize: 1,
    55  			fakeFetch: getFakeFetch(1),
    56  		},
    57  		{
    58  			name:      "larger batch",
    59  			first:     0,
    60  			treeSize:  110,
    61  			batchSize: 10,
    62  			fakeFetch: getFakeFetch(10),
    63  		},
    64  		{
    65  			name:      "bigger batch than tree",
    66  			first:     0,
    67  			treeSize:  9,
    68  			batchSize: 10,
    69  			fakeFetch: getFakeFetch(10),
    70  		},
    71  		{
    72  			name:      "batch size non-divisor of range",
    73  			first:     0,
    74  			treeSize:  107,
    75  			batchSize: 10,
    76  			fakeFetch: getFakeFetch(10),
    77  		},
    78  	} {
    79  		t.Run(test.name, func(t *testing.T) {
    80  			wrc := make(chan workerResult)
    81  
    82  			fw := fetchWorker{
    83  				label:      test.name,
    84  				start:      test.first,
    85  				treeSize:   test.treeSize,
    86  				increment:  uint64(test.batchSize),
    87  				count:      test.batchSize,
    88  				out:        wrc,
    89  				batchFetch: test.fakeFetch,
    90  			}
    91  
    92  			go fw.run(context.Background())
    93  
    94  			var seen, i int
    95  			for r := range wrc {
    96  				if r.err != nil {
    97  					t.Fatal(r.err)
    98  				}
    99  				if got, want := r.start, uint64(i*int(test.batchSize)); got != want {
   100  					t.Errorf("%d got != want (%d != %d)", i, got, want)
   101  				}
   102  				seen = seen + len(r.leaves)
   103  				i++
   104  			}
   105  			if seen != int(test.treeSize) {
   106  				t.Errorf("expected to see %d leaves but saw %d", test.treeSize, seen)
   107  			}
   108  		})
   109  	}
   110  }
   111  
   112  func TestBulk(t *testing.T) {
   113  	for _, test := range []testCase{
   114  		{
   115  			name:      "smallest batch",
   116  			first:     0,
   117  			treeSize:  10,
   118  			batchSize: 1,
   119  			workers:   1,
   120  			fakeFetch: getFakeFetch(1),
   121  		},
   122  		{
   123  			name:      "larger batch",
   124  			first:     0,
   125  			treeSize:  110,
   126  			batchSize: 10,
   127  			workers:   4,
   128  			fakeFetch: getFakeFetch(10),
   129  		},
   130  		{
   131  			name:      "bigger batch than tree",
   132  			first:     0,
   133  			treeSize:  9,
   134  			batchSize: 10,
   135  			workers:   1,
   136  			fakeFetch: getFakeFetch(10),
   137  		},
   138  		{
   139  			name:      "batch size equals tree size",
   140  			first:     0,
   141  			treeSize:  10,
   142  			batchSize: 10,
   143  			workers:   1,
   144  			fakeFetch: getFakeFetch(10),
   145  		},
   146  		{
   147  			name:      "batch size non-divisor of range",
   148  			first:     0,
   149  			treeSize:  107,
   150  			batchSize: 10,
   151  			workers:   4,
   152  			fakeFetch: getFakeFetch(10),
   153  		},
   154  	} {
   155  		t.Run(test.name, func(t *testing.T) {
   156  			brc := make(chan BulkResult)
   157  
   158  			go Bulk(context.Background(), test.first, test.treeSize, test.fakeFetch, test.workers, test.batchSize, brc)
   159  
   160  			i := 0
   161  			for br := range brc {
   162  				if br.Err != nil {
   163  					t.Fatal(br.Err)
   164  				}
   165  				if got, want := string(br.Leaf), strconv.Itoa(i); got != want {
   166  					t.Errorf("%d got != want (%q != %q)", i, got, want)
   167  				}
   168  				i++
   169  			}
   170  			if i != int(test.treeSize) {
   171  				t.Errorf("expected %d leaves, got %d", test.treeSize, i)
   172  			}
   173  		})
   174  	}
   175  }
   176  
   177  func TestBulkCancelled(t *testing.T) {
   178  	brc := make(chan BulkResult, 10)
   179  	var first uint64
   180  	var treeSize uint64 = 1000
   181  	var workers uint = 4
   182  	var batchSize uint = 10
   183  
   184  	fakeFetch := getFakeFetch(10)
   185  	ctx, cancel := context.WithCancel(context.Background())
   186  	defer cancel()
   187  
   188  	go Bulk(ctx, first, treeSize, fakeFetch, workers, batchSize, brc)
   189  
   190  	seen := 0
   191  	for i := 0; i < int(treeSize); i++ {
   192  		br := <-brc
   193  		if br.Err != nil {
   194  			break
   195  		}
   196  		seen++
   197  		if seen == 10 {
   198  			cancel()
   199  		}
   200  	}
   201  	if seen == int(treeSize) {
   202  		t.Error("Expected cancellation to prevent all leaves being read")
   203  	}
   204  }
   205  
   206  func TestBulkIncomplete(t *testing.T) {
   207  	for _, test := range []testCase{
   208  		{
   209  			name:      "incomplete first batch",
   210  			first:     0,
   211  			treeSize:  100,
   212  			batchSize: 10,
   213  			workers:   4,
   214  			wantErr:   false,
   215  			fakeFetch: func(start uint64, leaves [][]byte) (uint64, error) {
   216  				fetched := uint64(len(leaves))
   217  				for i := range leaves {
   218  					leaves[i] = []byte(strconv.Itoa(int(start) + i))
   219  					if start == 0 && i == 4 {
   220  						fetched = 5
   221  						break
   222  					}
   223  				}
   224  				return fetched, nil
   225  			},
   226  		},
   227  		{
   228  			name:      "incomplete last batch",
   229  			first:     0,
   230  			treeSize:  100,
   231  			batchSize: 10,
   232  			workers:   4,
   233  			wantErr:   true,
   234  			fakeFetch: func(start uint64, leaves [][]byte) (uint64, error) {
   235  				fetched := uint64(len(leaves))
   236  				for i := range leaves {
   237  					leaves[i] = []byte(strconv.Itoa(int(start) + i))
   238  					if start == 90 && i == 4 {
   239  						fetched = 5
   240  						break
   241  					}
   242  				}
   243  				return fetched, nil
   244  			},
   245  		},
   246  		{
   247  			name:      "incomplete middle batch",
   248  			first:     0,
   249  			treeSize:  100,
   250  			batchSize: 10,
   251  			workers:   4,
   252  			wantErr:   true,
   253  			fakeFetch: func(start uint64, leaves [][]byte) (uint64, error) {
   254  				fetched := uint64(len(leaves))
   255  				for i := range leaves {
   256  					leaves[i] = []byte(strconv.Itoa(int(start) + i))
   257  					if start == 50 && i == 4 {
   258  						fetched = 5
   259  						break
   260  					}
   261  				}
   262  				return fetched, nil
   263  			},
   264  		},
   265  	} {
   266  		t.Run(test.name, func(t *testing.T) {
   267  			brc := make(chan BulkResult)
   268  			go Bulk(context.Background(), test.first, test.treeSize, test.fakeFetch, test.workers, test.batchSize, brc)
   269  
   270  			i := 0
   271  			var err error
   272  			for br := range brc {
   273  				if br.Err != nil && !test.wantErr {
   274  					t.Fatal(br.Err)
   275  				}
   276  				if br.Err != nil && test.wantErr {
   277  					err = br.Err
   278  				}
   279  				if got, want := string(br.Leaf), strconv.Itoa(i); got != want && err == nil {
   280  					t.Fatalf("%d got != want (%q != %q)", i, got, want)
   281  				}
   282  				i++
   283  			}
   284  			if err == nil && test.wantErr {
   285  				t.Errorf("expected error, got none")
   286  			}
   287  			if err != nil && !test.wantErr {
   288  				t.Errorf("unexpected error: %v", err)
   289  			}
   290  			if i != int(test.treeSize) && !test.wantErr {
   291  				t.Errorf("expected %d leaves, got %d", test.treeSize, i)
   292  			}
   293  		})
   294  	}
   295  }
   296  
   297  func BenchmarkBulk(b *testing.B) {
   298  	for _, test := range []struct {
   299  		workers    uint
   300  		batchSize  uint
   301  		fetchDelay time.Duration
   302  		quota      int64
   303  	}{
   304  		{
   305  			workers:    20,
   306  			batchSize:  10,
   307  			fetchDelay: 50 * time.Microsecond,
   308  		},
   309  		{
   310  			workers:    20,
   311  			batchSize:  1,
   312  			fetchDelay: 50 * time.Microsecond,
   313  		},
   314  		{
   315  			workers:    1,
   316  			batchSize:  1,
   317  			fetchDelay: 50 * time.Microsecond,
   318  		},
   319  		{
   320  			workers:    1,
   321  			batchSize:  200,
   322  			fetchDelay: 50 * time.Microsecond,
   323  		},
   324  		{
   325  			workers:    20,
   326  			batchSize:  10,
   327  			fetchDelay: 50 * time.Microsecond,
   328  			quota:      1000,
   329  		},
   330  	} {
   331  		b.Run(fmt.Sprintf("w=%d,b=%d,delay=%s,q=%d", test.workers, test.batchSize, test.fetchDelay, test.quota), func(b *testing.B) {
   332  			brc := make(chan BulkResult, 10)
   333  			var first uint64
   334  
   335  			ctx, cancel := context.WithCancel(context.Background())
   336  			defer cancel()
   337  
   338  			take := func(n int) error { return nil }
   339  			if test.quota > 0 {
   340  				th := throttle{
   341  					quota:  test.quota,
   342  					refill: test.quota,
   343  				}
   344  				go th.startRefillLoop(ctx)
   345  				take = th.take
   346  			}
   347  
   348  			fakeFetch := func(start uint64, leaves [][]byte) (uint64, error) {
   349  				time.Sleep(test.fetchDelay)
   350  				if err := take(len(leaves)); err != nil {
   351  					return 0, err
   352  				}
   353  				for i := range leaves {
   354  					// Allocate a non-trivial amount of memory for the leaf.
   355  					leaf := make([]byte, 1024)
   356  					leaves[i] = leaf
   357  				}
   358  				return uint64(len(leaves)), nil
   359  			}
   360  
   361  			const consumeSize = 1000
   362  			go Bulk(ctx, first, uint64(b.N*consumeSize), fakeFetch, test.workers, test.batchSize, brc)
   363  
   364  			for n := 0; n < b.N; n++ {
   365  				for i := 0; i < consumeSize; i++ {
   366  					br := <-brc
   367  					if br.Err != nil {
   368  						b.Fatal(br.Err)
   369  					}
   370  				}
   371  			}
   372  		})
   373  	}
   374  }
   375  
   376  type throttle struct {
   377  	quota  int64
   378  	refill int64
   379  }
   380  
   381  func (t *throttle) take(n int) error {
   382  	if atomic.AddInt64(&t.quota, int64(n*-1)) > 0 {
   383  		return nil
   384  	}
   385  	return errors.New("out of quota")
   386  }
   387  
   388  func (t *throttle) startRefillLoop(ctx context.Context) {
   389  	tik := time.NewTicker(10 * time.Millisecond)
   390  	for {
   391  		select {
   392  		case <-ctx.Done():
   393  			return
   394  		case <-tik.C:
   395  			atomic.StoreInt64(&t.quota, t.refill)
   396  		}
   397  	}
   398  }