github.com/Schaudge/grailbase@v0.0.0-20240223061707-44c758a471c0/recordio/deprecated/packed_test.go (about) 1 // Copyright 2017 GRAIL, Inc. All rights reserved. 2 // Use of this source code is governed by the Apache-2.0 3 // license that can be found in the LICENSE file. 4 5 package deprecated_test 6 7 import ( 8 "bytes" 9 "crypto/sha256" 10 "encoding/hex" 11 "encoding/json" 12 "fmt" 13 "math/rand" 14 "reflect" 15 "sync" 16 "testing" 17 "time" 18 19 "github.com/Schaudge/grailbase/recordio/deprecated" 20 "github.com/Schaudge/grailbase/recordio/internal" 21 "github.com/grailbio/testutil" 22 "github.com/grailbio/testutil/assert" 23 "github.com/grailbio/testutil/expect" 24 ) 25 26 func TestPackedWriteRead(t *testing.T) { 27 sl := func(s ...string) [][]byte { 28 var bs [][]byte 29 for _, t := range s { 30 bs = append(bs, []byte(t)) 31 } 32 return bs 33 } 34 read := func(sc deprecated.LegacyScanner) string { 35 if !sc.Scan() { 36 assert.NoError(t, sc.Err()) 37 return "eof" 38 } 39 return string(sc.Bytes()) 40 } 41 42 for _, tc := range []struct { 43 in [][]byte 44 maxItems, maxBytes uint32 45 nreads int 46 }{ 47 {sl(""), 1, 10, 1}, 48 {sl("a", "b"), 2, 10, 1}, 49 {sl("", "", ""), 2, 10, 2}, 50 {sl("hello", "world", "line", "2", "line3"), 2, 100, 3}, 51 {sl("a", "b", "c", "d", "e", "f", "g"), 100, 2, 4}, 52 } { 53 out := &bytes.Buffer{} 54 nflushs := 0 55 wropts := deprecated.LegacyPackedWriterOpts{ 56 MaxItems: tc.maxItems, 57 MaxBytes: tc.maxBytes, 58 Flushed: func() error { nflushs++; return nil }, 59 } 60 wr := deprecated.NewLegacyPackedWriter(out, wropts) 61 62 for _, p := range tc.in { 63 n, err := wr.Write(p) 64 assert.NoError(t, err) 65 assert.True(t, n == len(p)) 66 } 67 assert.NoError(t, wr.Flush()) 68 // Make sure Flushing with nothing to flush has no effect. 69 assert.NoError(t, wr.Flush()) 70 assert.EQ(t, bytes.Count(out.Bytes(), internal.MagicPacked[:]), tc.nreads) 71 assert.EQ(t, nflushs, tc.nreads) 72 assert.EQ(t, bytes.Count(out.Bytes(), internal.MagicLegacyUnpacked[:]), 0) 73 74 sc := deprecated.NewLegacyPackedScanner(bytes.NewReader(out.Bytes()), deprecated.LegacyPackedScannerOpts{}) 75 for _, expected := range tc.in { 76 expect.EQ(t, read(sc), string(expected)) 77 } 78 expect.EQ(t, read(sc), "eof") 79 sc.Reset(bytes.NewReader(out.Bytes())) 80 for _, expected := range tc.in[:1] { 81 expect.EQ(t, read(sc), string(expected)) 82 } 83 sc.Reset(bytes.NewReader(out.Bytes())) 84 for _, expected := range tc.in { 85 expect.EQ(t, read(sc), string(expected)) 86 } 87 expect.EQ(t, read(sc), "eof") 88 } 89 } 90 91 func TestPackedMarshal(t *testing.T) { 92 type js struct { 93 I int 94 IS string 95 } 96 wropts := deprecated.LegacyPackedWriterOpts{ 97 MaxItems: 2, 98 MaxBytes: 0, 99 } 100 wropts.Marshal = func(scratch []byte, v interface{}) ([]byte, error) { 101 return json.Marshal(v) 102 } 103 scopts := deprecated.LegacyPackedScannerOpts{} 104 scopts.Unmarshal = json.Unmarshal 105 for i, tc := range []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 100, 233} { 106 out := &bytes.Buffer{} 107 wr := deprecated.NewLegacyPackedWriter(out, wropts) 108 for i := 0; i < tc; i++ { 109 if _, err := wr.Marshal(&js{i, fmt.Sprintf("%d", i)}); err != nil { 110 t.Fatalf("%v: %v", i, err) 111 } 112 } 113 if err := wr.Flush(); err != nil { 114 t.Fatalf("%v: %v", i, err) 115 } 116 sc := deprecated.NewLegacyPackedScanner(out, scopts) 117 data := make([]js, tc) 118 next := 0 119 for sc.Scan() { 120 if err := sc.Unmarshal(&data[next]); err != nil { 121 t.Fatalf("%v: %v", next, err) 122 } 123 next++ 124 } 125 if err := sc.Err(); err != nil { 126 t.Fatalf("%v: %v", i, err) 127 } 128 for i, d := range data { 129 w := &js{i, fmt.Sprintf("%v", i)} 130 if got, want := &d, w; !reflect.DeepEqual(got, want) { 131 t.Errorf("%v: got %v, want %v", i, got, want) 132 } 133 } 134 if got, want := len(data), tc; got != want { 135 t.Errorf("%v: got %v, want %v", i, got, want) 136 } 137 } 138 } 139 func TestPackedMixed(t *testing.T) { 140 type js struct { 141 I int 142 IS string 143 } 144 indexedObjs := 0 145 indexedBufs := 0 146 wropts := deprecated.LegacyPackedWriterOpts{ 147 MaxItems: 2, 148 MaxBytes: 0, 149 } 150 wropts.Marshal = func(scratch []byte, v interface{}) ([]byte, error) { 151 return json.Marshal(v) 152 } 153 wropts.Index = func(record, recordSize, items uint64) (deprecated.ItemIndexFunc, error) { 154 return func(offset, extent uint64, v interface{}, p []byte) error { 155 if p != nil { 156 indexedBufs++ 157 } 158 if v != nil { 159 indexedObjs++ 160 } 161 return nil 162 }, nil 163 } 164 scopts := deprecated.LegacyPackedScannerOpts{} 165 scopts.Unmarshal = json.Unmarshal 166 wantIndexedObjs, wantIndexedBufs := 0, 0 167 for i, tc := range []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 100, 233} { 168 out := &bytes.Buffer{} 169 wr := deprecated.NewLegacyPackedWriter(out, wropts) 170 for i := 0; i < tc; i++ { 171 wantIndexedBufs++ 172 if i%2 == 0 { 173 buf, err := json.Marshal(&js{i, fmt.Sprintf("%d", i)}) 174 if err != nil { 175 t.Fatalf("%v: %v", i, err) 176 } 177 wr.Write(buf) 178 } else { 179 if _, err := wr.Marshal(&js{i, fmt.Sprintf("%d", i)}); err != nil { 180 t.Fatalf("%v: %v", i, err) 181 } 182 wantIndexedObjs++ 183 } 184 } 185 if err := wr.Flush(); err != nil { 186 t.Fatalf("%v: %v", i, err) 187 } 188 sc := deprecated.NewLegacyPackedScanner(out, scopts) 189 data := make([]js, tc) 190 next := 0 191 for sc.Scan() { 192 if err := sc.Unmarshal(&data[next]); err != nil { 193 t.Fatalf("%v: %v", next, err) 194 } 195 next++ 196 } 197 if err := sc.Err(); err != nil { 198 t.Fatalf("%v: %v", i, err) 199 } 200 for i, d := range data { 201 w := &js{i, fmt.Sprintf("%v", i)} 202 if got, want := &d, w; !reflect.DeepEqual(got, want) { 203 t.Errorf("%v: got %v, want %v", i, got, want) 204 } 205 } 206 if got, want := len(data), tc; got != want { 207 t.Errorf("%v: got %v, want %v", i, got, want) 208 } 209 } 210 211 if got, want := indexedBufs, wantIndexedBufs; got != want { 212 t.Errorf("got %v, want %v", got, want) 213 } 214 if got, want := indexedObjs, wantIndexedObjs; got != want { 215 t.Errorf("got %v, want %v", got, want) 216 } 217 } 218 219 func TestPackedMax(t *testing.T) { 220 max := deprecated.MaxPackedItems 221 deprecated.MaxPackedItems = 2 222 defer func() { 223 deprecated.MaxPackedItems = max 224 }() 225 out := &bytes.Buffer{} 226 wropts := deprecated.LegacyPackedWriterOpts{ 227 MaxItems: 200, 228 MaxBytes: 0, 229 } 230 wr := deprecated.NewLegacyPackedWriter(out, wropts) 231 wr.Write([]byte("hello 1")) 232 wr.Write([]byte("hello 2")) 233 wr.Write([]byte("hello 3")) 234 wr.Write([]byte("hello 4")) 235 wr.Write([]byte("hello 5")) 236 if err := wr.Flush(); err != nil { 237 t.Fatal(err) 238 } 239 if got, want := bytes.Count(out.Bytes(), internal.MagicPacked[:]), 3; got != want { 240 t.Errorf("got %v, want %v", got, want) 241 } 242 } 243 244 func TestPackedErrors(t *testing.T) { 245 wropts := deprecated.LegacyPackedWriterOpts{ 246 MaxItems: 2, 247 MaxBytes: 10, 248 } 249 buf := &bytes.Buffer{} 250 wr := deprecated.NewLegacyPackedWriter(buf, wropts) 251 252 bigbuf := [50]byte{} 253 _, err := wr.Write(bigbuf[:]) 254 expect.HasSubstr(t, err, "buffer is too large 50 > 10") 255 256 wropts.Marshal = func(scratch []byte, v interface{}) ([]byte, error) { 257 return json.Marshal(v) 258 } 259 wr = deprecated.NewLegacyPackedWriter(buf, wropts) 260 _, err = wr.Marshal(bigbuf[:]) 261 expect.HasSubstr(t, err, "buffer is too large 70 > 10") 262 263 writeError := func(offset int, short bool, msg string) { 264 ew := &fakeWriter{errAt: offset, short: short} 265 wropts := deprecated.LegacyPackedWriterOpts{MaxItems: 1, MaxBytes: 10} 266 wr := deprecated.NewLegacyPackedWriter(ew, wropts) 267 _, err := wr.Write([]byte("hello")) 268 expect.NoError(t, err, "first write succeeds") 269 _, err = wr.Write([]byte("hello")) 270 expect.HasSubstr(t, err, msg) 271 } 272 writeError(0, false, "recordio: failed to write header") 273 writeError(22, false, "recordio: failed to write record") 274 writeError(25, true, "recordio: buffered write too short") 275 276 wropts.Transform = func(bufs [][]byte) ([]byte, error) { 277 return nil, fmt.Errorf("transform oops") 278 } 279 wr = deprecated.NewLegacyPackedWriter(buf, wropts) 280 wr.Write([]byte("oh")) 281 err = wr.Flush() 282 expect.HasSubstr(t, err, "transform oops") 283 284 buf.Reset() 285 wr = deprecated.NewLegacyPackedWriter(buf, deprecated.LegacyPackedWriterOpts{}) 286 wr.Write([]byte("")) 287 if err := wr.Flush(); err != nil { 288 t.Fatal(err) 289 } 290 291 shortReadError := func(offset int, msg string) { 292 f := buf.Bytes() 293 tmp := make([]byte, len(f)) 294 copy(tmp, f) 295 s := deprecated.NewLegacyPackedScanner(bytes.NewBuffer(tmp[:offset]), 296 deprecated.LegacyPackedScannerOpts{}) 297 if s.Scan() { 298 t.Errorf("expected false") 299 } 300 expect.HasSubstr(t, s.Err(), msg) 301 } 302 shortReadError(1, "unexpected EOF") 303 shortReadError(19, "unexpected EOF") 304 shortReadError(20, "short/long record") 305 306 scopts := deprecated.LegacyPackedScannerOpts{} 307 scopts.Transform = func(scratch, buf []byte) ([]byte, error) { 308 return nil, fmt.Errorf("transform oops") 309 } 310 sc := deprecated.NewLegacyPackedScanner(bytes.NewBuffer(buf.Bytes()), scopts) 311 if sc.Scan() { 312 t.Fatal("expected false") 313 } 314 if sc.Scan() { 315 t.Fatal("expect false") 316 } 317 expect.HasSubstr(t, sc.Err(), "transform oops") 318 } 319 320 func readAll(t *testing.T, buf *bytes.Buffer, opts deprecated.LegacyPackedScannerOpts) []string { 321 sc := deprecated.NewLegacyPackedScanner(buf, opts) 322 var read []string 323 for sc.Scan() { 324 read = append(read, string(sc.Bytes())) 325 } 326 assert.NoError(t, sc.Err()) 327 return read 328 } 329 330 func TestPackedTransform(t *testing.T) { 331 // Prepend __ and append ++ to every record. 332 wropts := deprecated.LegacyPackedWriterOpts{ 333 MaxItems: 2, 334 Transform: func(bufs [][]byte) ([]byte, error) { 335 r := []byte("__") 336 for _, b := range bufs { 337 r = append(r, b...) 338 } 339 return append(r, []byte("++")...), nil 340 }, 341 } 342 buf := &bytes.Buffer{} 343 wr := deprecated.NewLegacyPackedWriter(buf, wropts) 344 data := []string{"Hello", "World", "How", "Are", "You?"} 345 for _, d := range data { 346 wr.Write([]byte(d)) 347 } 348 wr.Flush() 349 350 // Scan with the __ and ++ in place. Each item in the each 351 // record is the same original size, but the contents are 352 // 'shifted' by the leading__ 353 expected := []string{"__Hel", "loWor", "__H", "owA", "__Yo"} 354 saved := make([]byte, len(buf.Bytes())) 355 copy(saved, buf.Bytes()) 356 357 read := readAll(t, buf, deprecated.LegacyPackedScannerOpts{}) 358 if got, want := len(expected), len(expected); got != want { 359 t.Errorf("got %v, want %v", got, want) 360 } 361 362 for i, r := range read { 363 if got, want := r, expected[i]; got != want { 364 t.Errorf("%d: got %v, want %v", i, got, want) 365 } 366 } 367 368 scopts := deprecated.LegacyPackedScannerOpts{} 369 // Strip the leading ++ and trailing ++ while scanning 370 scopts.Transform = func(scratch, buf []byte) ([]byte, error) { 371 return buf[2 : len(buf)-2], nil 372 } 373 buf = bytes.NewBuffer(saved) 374 read = readAll(t, buf, scopts) 375 if got, want := len(read), len(data); got != want { 376 t.Errorf("got %v, want %v", got, want) 377 } 378 379 for i, r := range read { 380 if got, want := r, data[i]; got != want { 381 t.Errorf("%d: got %v, want %v", i, got, want) 382 } 383 } 384 } 385 386 func createBufs(nBufs, maxSize int) (map[string][]byte, error) { 387 rand.Seed(time.Now().UnixNano()) 388 out := map[string][]byte{} 389 for i := 0; i < nBufs; i++ { 390 size := rand.Intn(maxSize) 391 if size == 0 { 392 size = maxSize / 2 393 } 394 buf := make([]byte, size) 395 n, err := rand.Read(buf) 396 if err != nil || n != cap(buf) { 397 return nil, fmt.Errorf("failed to generate %d bytes of random data: %d != %d: %v", cap(buf), n, cap(buf), err) 398 } 399 s := sha256.Sum256(buf) 400 k := hex.EncodeToString(s[:]) 401 if _, present := out[k]; present { 402 // avoid dups. 403 i-- 404 continue 405 } 406 out[k] = buf 407 } 408 return out, nil 409 } 410 411 func TestPackedConcurrentWrites(t *testing.T) { 412 nbufs := 200 413 maxBufSize := 1 * 1024 * 1024 414 data, err := createBufs(nbufs, maxBufSize) 415 assert.NoError(t, err) 416 417 tmpdir, cleanup := testutil.TempDir(t, "", "encrypted-test") 418 defer testutil.NoCleanupOnError(t, cleanup, "tmpdir: ", tmpdir) 419 buf := &bytes.Buffer{} 420 wr := deprecated.NewLegacyPackedWriter(buf, deprecated.LegacyPackedWriterOpts{MaxItems: 10}) 421 422 var wg sync.WaitGroup 423 wg.Add(len(data)) 424 ch := make(chan error, nbufs) 425 for _, v := range data { 426 go func(b []byte) { 427 _, err := wr.Write(b) 428 ch <- err 429 wg.Done() 430 }(v) 431 } 432 433 wg.Wait() 434 wr.Flush() 435 close(ch) 436 for err := range ch { 437 assert.NoError(t, err) 438 } 439 440 scanner := deprecated.NewLegacyPackedScanner(buf, deprecated.LegacyPackedScannerOpts{}) 441 for scanner.Scan() { 442 buf := scanner.Bytes() 443 sum := sha256.Sum256(buf) 444 key := hex.EncodeToString(sum[:]) 445 if _, present := data[key]; !present { 446 t.Errorf("corrupt/wrong data %v is not a sha256 of one of the test bufs", key) 447 continue 448 } 449 data[key] = nil 450 } 451 assert.NoError(t, scanner.Err()) 452 453 for k, v := range data { 454 if v != nil { 455 t.Errorf("failed to read buffer with sha256 of %v", k) 456 } 457 } 458 }