github.com/grailbio/base@v0.0.11/digest/digestrw_test.go (about)

     1  // Copyright 2017 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache 2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  package digest_test
     6  
     7  import (
     8  	"context"
     9  	"crypto"
    10  	_ "crypto/sha256" // Required for the SHA256 constant.
    11  	"fmt"
    12  	"io"
    13  	"io/ioutil"
    14  	"math/rand"
    15  	"os"
    16  	"path"
    17  	"strings"
    18  	"sync"
    19  	"testing"
    20  	"time"
    21  
    22  	"github.com/aws/aws-sdk-go/aws"
    23  	"github.com/aws/aws-sdk-go/service/s3"
    24  	"github.com/aws/aws-sdk-go/service/s3/s3manager"
    25  	"github.com/grailbio/base/digest"
    26  	"github.com/grailbio/base/traverse"
    27  	"github.com/grailbio/testutil"
    28  	"github.com/grailbio/testutil/s3test"
    29  )
    30  
    31  func min(a, b int64) int64 {
    32  	if a < b {
    33  		return a
    34  	}
    35  	return b
    36  }
    37  
    38  func TestDigestReader(t *testing.T) {
    39  	digester := digest.Digester(crypto.SHA256)
    40  
    41  	dataSize := int64(950)
    42  	segmentSize := int64(100)
    43  
    44  	for _, test := range []struct {
    45  		reader io.Reader
    46  		order  []int64
    47  	}{
    48  		{
    49  			&testutil.FakeContentAt{T: t, SizeInBytes: dataSize, Current: 0, FailureRate: 0},
    50  			[]int64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9},
    51  		},
    52  		{
    53  			&testutil.FakeContentAt{T: t, SizeInBytes: dataSize, Current: 0, FailureRate: 0},
    54  			[]int64{1, 0, 3, 2, 5, 4, 7, 6, 9, 8},
    55  		},
    56  		{
    57  			&testutil.FakeContentAt{T: t, SizeInBytes: dataSize, Current: 0, FailureRate: 0},
    58  			[]int64{9, 8, 7, 6, 5, 4, 3, 2, 1, 0},
    59  		},
    60  	} {
    61  		dra := digester.NewReader(test.reader)
    62  		readerAt, ok := dra.(io.ReaderAt)
    63  		if !ok {
    64  			t.Fatal("reader does not support ReaderAt")
    65  		}
    66  
    67  		err := traverse.Each(len(test.order), func(jobIdx int) error {
    68  			time.Sleep(10 * time.Duration(jobIdx) * time.Millisecond)
    69  			index := test.order[jobIdx]
    70  			size := min(segmentSize, (dataSize-index)*segmentSize)
    71  			d := make([]byte, size)
    72  			_, err := readerAt.ReadAt(d, index*int64(segmentSize))
    73  			return err
    74  		})
    75  		if err != nil {
    76  			t.Fatal(err)
    77  		}
    78  
    79  		actual, err := dra.Digest()
    80  		if err != nil {
    81  			t.Fatal(err)
    82  		}
    83  
    84  		writer := digester.NewWriter()
    85  		content := &testutil.FakeContentAt{T: t, SizeInBytes: dataSize, Current: 0, FailureRate: 0}
    86  		if _, err := io.Copy(writer, content); err != nil {
    87  			t.Fatal(err)
    88  		}
    89  		expected := writer.Digest()
    90  
    91  		if actual != expected {
    92  			t.Fatalf("digest mismatch: %s vs %s", actual, expected)
    93  		}
    94  	}
    95  }
    96  
    97  func TestDigestWriter(t *testing.T) {
    98  	td, err := ioutil.TempDir("", "grail_cache_test")
    99  	if err != nil {
   100  		t.Fatal(err)
   101  	}
   102  	defer os.RemoveAll(td)
   103  
   104  	digester := digest.Digester(crypto.SHA256)
   105  
   106  	tests := [][]int{
   107  		{0, 1, 2, 3, 4, 5, 6, 7, 8, 9},
   108  		{9, 8, 7, 6, 5, 4, 3, 2, 1, 0},
   109  	}
   110  	for i := 0; i < 50; i++ {
   111  		tests = append(tests, rand.Perm(10))
   112  	}
   113  
   114  	for _, test := range tests {
   115  		testFile := path.Join(td, "testfile")
   116  
   117  		output, err := os.Create(testFile)
   118  		if err != nil {
   119  			t.Fatal(err)
   120  		}
   121  
   122  		dwa := digester.NewWriterAt(context.Background(), output)
   123  
   124  		err = traverse.Each(len(test), func(jobIdx int) error {
   125  			time.Sleep(5 * time.Duration(jobIdx) * time.Millisecond)
   126  			i := test[jobIdx]
   127  			segmentString := strings.Repeat(fmt.Sprintf("%c", 'a'+i), 100)
   128  			offset := int64(i * len(segmentString))
   129  
   130  			_, e := dwa.WriteAt([]byte(segmentString), offset)
   131  			return e
   132  		})
   133  		output.Close()
   134  		if err != nil {
   135  			t.Fatal(err)
   136  		}
   137  
   138  		expected, err := dwa.Digest()
   139  		if err != nil {
   140  			t.Fatal(err)
   141  		}
   142  
   143  		input, err := os.Open(testFile)
   144  		if err != nil {
   145  			t.Fatal(err)
   146  		}
   147  
   148  		w := digester.NewWriter()
   149  		io.Copy(w, input)
   150  
   151  		if got := w.Digest(); expected != got {
   152  			t.Fatalf("expected: %v, got: %v", expected, got)
   153  		}
   154  
   155  		input.Close()
   156  	}
   157  }
   158  
   159  func TestDigestWriterContext(t *testing.T) {
   160  	f, err := ioutil.TempFile("", "")
   161  	if err != nil {
   162  		t.Fatal(err)
   163  	}
   164  	defer os.Remove(f.Name())
   165  	digester := digest.Digester(crypto.SHA256)
   166  	ctx, cancel := context.WithCancel(context.Background())
   167  	w := digester.NewWriterAt(ctx, f)
   168  	_, err = w.WriteAt([]byte{1, 2, 3}, 0)
   169  	if err != nil {
   170  		t.Fatal(err)
   171  	}
   172  	_, err = w.WriteAt([]byte{4, 5, 6}, 3)
   173  	if err != nil {
   174  		t.Fatal(err)
   175  	}
   176  	// By now we know the looper is up and running.
   177  	var wg sync.WaitGroup
   178  	wg.Add(10)
   179  	for i := int64(0); i < 10; i++ {
   180  		go func(i int64) {
   181  			_, err := w.WriteAt([]byte{1}, 100+i)
   182  			if got, want := err, ctx.Err(); got != want {
   183  				t.Errorf("got %v, want %v", got, want)
   184  			}
   185  			wg.Done()
   186  		}(i)
   187  	}
   188  	cancel()
   189  	wg.Wait()
   190  }
   191  
   192  func TestS3ManagerUpload(t *testing.T) {
   193  	client := s3test.NewClient(t, "test-bucket")
   194  
   195  	size := int64(93384620) // Completely random number.
   196  
   197  	digester := digest.Digester(crypto.SHA256)
   198  	contentAt := &testutil.FakeContentAt{T: t, SizeInBytes: size, Current: 0, FailureRate: 0}
   199  	client.SetFileContentAt("test/test/test", contentAt, "fakesha")
   200  	reader := digester.NewReader(contentAt)
   201  
   202  	// The relationship between size and PartSize is:
   203  	//    - size/PartSize > 10 to utilize multiple parallel uploads.
   204  	//    - size%PartSize != 0 so the last part is a partial upload.
   205  	//    - size is small enough that the unittest runs quickly.
   206  	//    - PartSize is the minimum allowed.
   207  	uploader := s3manager.NewUploaderWithClient(client, func(d *s3manager.Uploader) { d.Concurrency = 30; d.PartSize = 5242880 })
   208  
   209  	input := &s3manager.UploadInput{
   210  		Bucket: aws.String("test-bucket"),
   211  		Key:    aws.String("test/test/test"),
   212  		Body:   reader,
   213  	}
   214  
   215  	// Perform an upload.
   216  	_, err := uploader.UploadWithContext(context.Background(), input, func(*s3manager.Uploader) {})
   217  	if err != nil {
   218  		t.Fatal(err)
   219  	}
   220  
   221  	got, err := reader.Digest()
   222  	if err != nil {
   223  		t.Fatal(err)
   224  	}
   225  
   226  	dw := digester.NewWriter()
   227  	content := &testutil.FakeContentAt{T: t, SizeInBytes: size, Current: 0, FailureRate: 0}
   228  	if _, err := io.Copy(dw, content); err != nil {
   229  		t.Fatal(err)
   230  	}
   231  	expected := dw.Digest()
   232  
   233  	if got != expected {
   234  		t.Fatalf("digest mismatch, expected %s, got %s", expected, got)
   235  	}
   236  }
   237  
   238  func TestS3ManagerDownload(t *testing.T) {
   239  	client := s3test.NewClient(t, "test-bucket")
   240  	client.NumMaxRetries = 10
   241  
   242  	size := int64(86738922) // Completely random number.
   243  
   244  	digester := digest.Digester(crypto.SHA256)
   245  	contentAt := &testutil.FakeContentAt{T: t, SizeInBytes: size, Current: 0, FailureRate: 0.001}
   246  	client.SetFileContentAt("test/test/test", contentAt, "fakesha")
   247  	writer := digester.NewWriterAt(context.Background(), contentAt)
   248  
   249  	// The relationship between size and PartSize is:
   250  	//    - size/PartSize > 10 to utilize multiple parallel uploads.
   251  	//    - size%PartSize != 0 so the last part is a partial upload.
   252  	//    - size is small enough that the unittest runs quickly.
   253  	//    - PartSize is the minimum allowed.
   254  	downloader := s3manager.NewDownloaderWithClient(client, func(d *s3manager.Downloader) { d.Concurrency = 30; d.PartSize = 55242880 })
   255  
   256  	params := &s3.GetObjectInput{
   257  		Bucket: aws.String("test-bucket"),
   258  		Key:    aws.String("test/test/test"),
   259  	}
   260  	_, err := downloader.DownloadWithContext(
   261  		context.Background(),
   262  		writer,
   263  		params,
   264  	)
   265  	if err != nil {
   266  		t.Fatal(err)
   267  	}
   268  
   269  	got, err := writer.Digest()
   270  	if err != nil {
   271  		t.Fatal(err)
   272  	}
   273  
   274  	dw := digester.NewWriter()
   275  	content := &testutil.FakeContentAt{T: t, SizeInBytes: size, Current: 0, FailureRate: 0}
   276  	if _, err := io.Copy(dw, content); err != nil {
   277  		t.Fatal(err)
   278  	}
   279  	expected := dw.Digest()
   280  
   281  	if got != expected {
   282  		t.Fatalf("digest mismatch, expected %s, got %s", expected, got)
   283  	}
   284  }