github.com/Schaudge/grailbase@v0.0.0-20240223061707-44c758a471c0/digest/digestrw_test.go

// Copyright 2017 GRAIL, Inc. All rights reserved.
// Use of this source code is governed by the Apache 2.0
// license that can be found in the LICENSE file.

package digest_test

import (
	"context"
	"crypto"
	_ "crypto/sha256" // Required for the SHA256 constant.
	"fmt"
	"io"
	"io/ioutil"
	"math/rand"
	"os"
	"path"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/service/s3"
	"github.com/aws/aws-sdk-go/service/s3/s3manager"
	"github.com/Schaudge/grailbase/digest"
	"github.com/Schaudge/grailbase/traverse"
	"github.com/grailbio/testutil"
	"github.com/grailbio/testutil/s3test"
)

func min(a, b int64) int64 {
	if a < b {
		return a
	}
	return b
}

func TestDigestReader(t *testing.T) {
	digester := digest.Digester(crypto.SHA256)

	dataSize := int64(950)
	segmentSize := int64(100)

	for _, test := range []struct {
		reader io.Reader
		order  []int64
	}{
		{
			&testutil.FakeContentAt{T: t, SizeInBytes: dataSize, Current: 0, FailureRate: 0},
			[]int64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9},
		},
		{
			&testutil.FakeContentAt{T: t, SizeInBytes: dataSize, Current: 0, FailureRate: 0},
			[]int64{1, 0, 3, 2, 5, 4, 7, 6, 9, 8},
		},
		{
			&testutil.FakeContentAt{T: t, SizeInBytes: dataSize, Current: 0, FailureRate: 0},
			[]int64{9, 8, 7, 6, 5, 4, 3, 2, 1, 0},
		},
	} {
		dra := digester.NewReader(test.reader)
		readerAt, ok := dra.(io.ReaderAt)
		if !ok {
			t.Fatal("reader does not support ReaderAt")
		}

		err := traverse.Each(len(test.order), func(jobIdx int) error {
			time.Sleep(10 * time.Duration(jobIdx) * time.Millisecond)
			index := test.order[jobIdx]
			// Clamp the final segment to the bytes remaining past this offset.
			size := min(segmentSize, dataSize-index*segmentSize)
			d := make([]byte, size)
			_, err := readerAt.ReadAt(d, index*int64(segmentSize))
			return err
		})
		if err != nil {
			t.Fatal(err)
		}

		actual, err := dra.Digest()
		if err != nil {
			t.Fatal(err)
		}

		writer := digester.NewWriter()
		content := &testutil.FakeContentAt{T: t, SizeInBytes: dataSize, Current: 0, FailureRate: 0}
		if _, err := io.Copy(writer, content); err != nil {
			t.Fatal(err)
		}
		expected := writer.Digest()

		if actual != expected {
			t.Fatalf("digest mismatch: %s vs %s", actual, expected)
		}
	}
}

func TestDigestWriter(t *testing.T) {
	td, err := ioutil.TempDir("", "grail_cache_test")
	if err != nil {
		t.Fatal(err)
	}
	defer os.RemoveAll(td)

	digester := digest.Digester(crypto.SHA256)

	tests := [][]int{
		{0, 1, 2, 3, 4, 5, 6, 7, 8, 9},
		{9, 8, 7, 6, 5, 4, 3, 2, 1, 0},
	}
	for i := 0; i < 50; i++ {
		tests = append(tests, rand.Perm(10))
	}

	for _, test := range tests {
		testFile := path.Join(td, "testfile")

		output, err := os.Create(testFile)
		if err != nil {
			t.Fatal(err)
		}

		dwa := digester.NewWriterAt(context.Background(), output)

		err = traverse.Each(len(test), func(jobIdx int) error {
			time.Sleep(5 * time.Duration(jobIdx) * time.Millisecond)
			i := test[jobIdx]
			segmentString := strings.Repeat(fmt.Sprintf("%c", 'a'+i), 100)
			offset := int64(i * len(segmentString))

			_, e := dwa.WriteAt([]byte(segmentString), offset)
			return e
		})
		output.Close()
		if err != nil {
			t.Fatal(err)
		}

		expected, err := dwa.Digest()
		if err != nil {
			t.Fatal(err)
		}

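		// Verify: re-read the file sequentially through a plain digest writer
		// and check that it matches the digest produced by the out-of-order
		// WriteAt calls above.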
		input, err := os.Open(testFile)
		if err != nil {
			t.Fatal(err)
		}

		w := digester.NewWriter()
		if _, err := io.Copy(w, input); err != nil {
			t.Fatal(err)
		}

		if got := w.Digest(); expected != got {
			t.Fatalf("expected: %v, got: %v", expected, got)
		}

		input.Close()
	}
}

func TestDigestWriterContext(t *testing.T) {
	f, err := ioutil.TempFile("", "")
	if err != nil {
		t.Fatal(err)
	}
	defer os.Remove(f.Name())
	digester := digest.Digester(crypto.SHA256)
	ctx, cancel := context.WithCancel(context.Background())
	w := digester.NewWriterAt(ctx, f)
	_, err = w.WriteAt([]byte{1, 2, 3}, 0)
	if err != nil {
		t.Fatal(err)
	}
	_, err = w.WriteAt([]byte{4, 5, 6}, 3)
	if err != nil {
		t.Fatal(err)
	}
	// By now we know the looper is up and running.
	var wg sync.WaitGroup
	wg.Add(10)
	for i := int64(0); i < 10; i++ {
		go func(i int64) {
			_, err := w.WriteAt([]byte{1}, 100+i)
			if got, want := err, ctx.Err(); got != want {
				t.Errorf("got %v, want %v", got, want)
			}
			wg.Done()
		}(i)
	}
	cancel()
	wg.Wait()
}

func TestS3ManagerUpload(t *testing.T) {
	client := s3test.NewClient(t, "test-bucket")

	size := int64(93384620) // Completely random number.

	digester := digest.Digester(crypto.SHA256)
	contentAt := &testutil.FakeContentAt{T: t, SizeInBytes: size, Current: 0, FailureRate: 0}
	client.SetFileContentAt("test/test/test", contentAt, "fakesha")
	reader := digester.NewReader(contentAt)

	// The relationship between size and PartSize is:
	// - size/PartSize > 10 to utilize multiple parallel uploads.
	// - size%PartSize != 0 so the last part is a partial upload.
	// - size is small enough that the unit test runs quickly.
	// - PartSize is the minimum allowed.
	uploader := s3manager.NewUploaderWithClient(client, func(d *s3manager.Uploader) { d.Concurrency = 30; d.PartSize = 5242880 })

	input := &s3manager.UploadInput{
		Bucket: aws.String("test-bucket"),
		Key:    aws.String("test/test/test"),
		Body:   reader,
	}

	// Perform an upload.
	_, err := uploader.UploadWithContext(context.Background(), input, func(*s3manager.Uploader) {})
	if err != nil {
		t.Fatal(err)
	}

	got, err := reader.Digest()
	if err != nil {
		t.Fatal(err)
	}

	dw := digester.NewWriter()
	content := &testutil.FakeContentAt{T: t, SizeInBytes: size, Current: 0, FailureRate: 0}
	if _, err := io.Copy(dw, content); err != nil {
		t.Fatal(err)
	}
	expected := dw.Digest()

	if got != expected {
		t.Fatalf("digest mismatch, expected %s, got %s", expected, got)
	}
}

func TestS3ManagerDownload(t *testing.T) {
	client := s3test.NewClient(t, "test-bucket")
	client.NumMaxRetries = 10

	size := int64(86738922) // Completely random number.

	digester := digest.Digester(crypto.SHA256)
	contentAt := &testutil.FakeContentAt{T: t, SizeInBytes: size, Current: 0, FailureRate: 0.001}
	client.SetFileContentAt("test/test/test", contentAt, "fakesha")
	writer := digester.NewWriterAt(context.Background(), contentAt)

	// The relationship between size and PartSize is:
	// - size/PartSize > 10 to utilize multiple parallel downloads.
	// - size%PartSize != 0 so the last part is a partial download.
	// - size is small enough that the unit test runs quickly.
	// - PartSize is the minimum allowed.
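	//
	// The downloader writes each part at its own offset from a separate
	// goroutine; the writer returned by digester.NewWriterAt is expected to
	// order those writes (buffering or blocking ones that arrive ahead of
	// their turn) so the final digest matches a sequential pass over the
	// same content.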
	downloader := s3manager.NewDownloaderWithClient(client, func(d *s3manager.Downloader) { d.Concurrency = 30; d.PartSize = 5242880 })

	params := &s3.GetObjectInput{
		Bucket: aws.String("test-bucket"),
		Key:    aws.String("test/test/test"),
	}
	_, err := downloader.DownloadWithContext(
		context.Background(),
		writer,
		params,
	)
	if err != nil {
		t.Fatal(err)
	}

	got, err := writer.Digest()
	if err != nil {
		t.Fatal(err)
	}

	dw := digester.NewWriter()
	content := &testutil.FakeContentAt{T: t, SizeInBytes: size, Current: 0, FailureRate: 0}
	if _, err := io.Copy(dw, content); err != nil {
		t.Fatal(err)
	}
	expected := dw.Digest()

	if got != expected {
		t.Fatalf("digest mismatch, expected %s, got %s", expected, got)
	}
}
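
// The tests above exercise concurrent, out-of-order reads and writes. For
// reference, a minimal sequential sketch follows; it uses only calls already
// exercised in this file (Digester, NewWriter, Digest) and is illustrative
// rather than part of the original suite.
func ExampleDigester() {
	digester := digest.Digester(crypto.SHA256)

	// Stream a payload through a digest writer, then read back its SHA-256.
	w := digester.NewWriter()
	if _, err := io.Copy(w, strings.NewReader("hello, digest")); err != nil {
		panic(err)
	}
	fmt.Println(w.Digest())
}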