github.com/grailbio/base@v0.0.11/recordio/deprecated/recordio_test.go (about) 1 package deprecated_test 2 3 import ( 4 "bytes" 5 "fmt" 6 "io" 7 "io/ioutil" 8 "os" 9 "reflect" 10 "strings" 11 "testing" 12 13 "github.com/grailbio/base/recordio/deprecated" 14 "github.com/grailbio/base/recordio/internal" 15 "github.com/grailbio/testutil/assert" 16 "github.com/grailbio/testutil/expect" 17 "github.com/klauspost/compress/gzip" 18 "v.io/x/lib/gosh" 19 ) 20 21 func cat(args ...[]byte) []byte { 22 r := []byte{} 23 for _, a := range args { 24 r = append(r, a...) 25 } 26 return r 27 } 28 29 func cmp(a, b []byte) bool { 30 if len(a) != len(b) { 31 return false 32 } 33 for ia, va := range a { 34 if b[ia] != va { 35 return false 36 } 37 } 38 return true 39 } 40 41 func TestRecordioSimpleWriteRead(t *testing.T) { 42 c := func(args []byte, m string) []byte { 43 return cat(internal.MagicLegacyUnpacked[:], args, []byte(m)) 44 } 45 for i, tc := range []struct { 46 in string 47 out []byte 48 }{ 49 {"", c([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0x69, 0xDF, 0x22, 0x65}, "")}, 50 {"a", c([]byte{1, 0, 0, 0, 0, 0, 0, 0, 0xF7, 0xDF, 0x88, 0xA9}, "a")}, 51 {"hello\n", c([]byte{6, 0, 0, 0, 0, 0, 0, 0, 0xEE, 0xD6, 0x4D, 0xA3}, "hello\n")}, 52 } { 53 out := &bytes.Buffer{} 54 rw := deprecated.NewLegacyWriter(out, deprecated.LegacyWriterOpts{}) 55 n, err := rw.Write([]byte(tc.in)) 56 assert.NoError(t, err) 57 assert.EQ(t, n, len(tc.in)) 58 if got, want := out.Bytes(), tc.out; !cmp(got, want) { 59 t.Errorf("%d: got %v, want %v", i, got, want) 60 } 61 62 s := deprecated.NewLegacyScanner(out, deprecated.LegacyScannerOpts{}) 63 assert.True(t, s.Scan()) 64 b := s.Bytes() 65 if got, want := b, []byte(tc.in); !cmp(got, want) { 66 t.Errorf("%d: got %v, want %v", i, got, want) 67 } 68 if got, want := s.Scan(), false; got != want { 69 t.Errorf("%d: got %v, want %v", i, got, want) 70 } 71 if err := s.Err(); err != nil { 72 t.Errorf("%d: %v", i, err) 73 } 74 } 75 76 out := &bytes.Buffer{} 77 rw := deprecated.NewLegacyWriter(out, deprecated.LegacyWriterOpts{}) 78 n, err := rw.WriteSlices([]byte("hello"), []byte(" "), []byte("world")) 79 if err != nil { 80 t.Fatal(err) 81 } 82 if got, want := n, 11; got != want { 83 t.Errorf("got %v, want %v", got, want) 84 } 85 s := deprecated.NewLegacyScanner(out, deprecated.LegacyScannerOpts{}) 86 s.Scan() 87 if got, want := s.Bytes(), []byte("hello world"); !bytes.Equal(got, want) { 88 t.Errorf("got %s, want %s", got, want) 89 } 90 91 n, err = rw.WriteSlices(nil, []byte("hello"), []byte(" "), []byte("world")) 92 if err != nil { 93 t.Fatal(err) 94 } 95 if got, want := n, 11; got != want { 96 t.Errorf("got %v, want %v", got, want) 97 } 98 s = deprecated.NewLegacyScanner(out, deprecated.LegacyScannerOpts{}) 99 s.Scan() 100 if got, want := s.Bytes(), []byte("hello world"); !bytes.Equal(got, want) { 101 t.Errorf("got %s, want %s", got, want) 102 } 103 } 104 105 type fakeWriter struct { 106 errAt int 107 s int 108 short bool // set for a 'short writes' 109 } 110 111 func (fw *fakeWriter) Write(p []byte) (n int, err error) { 112 fw.s += len(p) 113 if fw.s > fw.errAt { 114 if fw.short { 115 return 0, nil 116 } 117 return 0, fmt.Errorf("error at %d", fw.errAt) 118 } 119 return len(p), nil 120 } 121 122 type fakeReader struct { 123 buf *bytes.Buffer 124 errAt int 125 s int 126 } 127 128 func (fr *fakeReader) Read(p []byte) (n int, err error) { 129 b := fr.buf.Next(len(p)) 130 fr.s += len(p) 131 if fr.s > fr.errAt { 132 return 0, fmt.Errorf("fail at %d", fr.errAt) 133 } 134 copy(p, b) 135 return len(b), nil 136 } 137 138 func TestRecordioErrors(t *testing.T) { 139 writeError := func(offset int, msg string) { 140 ew := &fakeWriter{errAt: offset} 141 rw := deprecated.NewLegacyWriter(ew, deprecated.LegacyWriterOpts{}) 142 _, err := rw.Write([]byte("hello")) 143 expect.HasSubstr(t, err, msg) 144 } 145 writeError(0, "failed to write header") 146 writeError(21, "failed to write record") 147 148 marshalError := func(offset int, msg string) { 149 ew := &fakeWriter{errAt: offset} 150 rw := deprecated.NewLegacyWriter(ew, deprecated.LegacyWriterOpts{ 151 Marshal: recordioMarshal, 152 }) 153 _, err := rw.Marshal(&TestPB{"oops"}) 154 expect.HasSubstr(t, err, msg) 155 } 156 marshalError(0, "failed to write header") 157 marshalError(21, "failed to write record") 158 159 buf := &bytes.Buffer{} 160 rw := deprecated.NewLegacyWriter(buf, deprecated.LegacyWriterOpts{}) 161 rw.Write([]byte("hello\n")) 162 163 corruptionError := func(offset int, msg string) { 164 f := buf.Bytes() 165 tmp := make([]byte, len(f)) 166 copy(tmp, f) 167 tmp[offset] = 0xff 168 s := deprecated.NewLegacyScanner(bytes.NewBuffer(tmp), deprecated.LegacyScannerOpts{}) 169 if s.Scan() { 170 t.Errorf("expected false") 171 } 172 expect.HasSubstr(t, s.Err(), msg) 173 } 174 175 corruptionError(0, "invalid magic number") 176 corruptionError(10, "crc check failed") 177 corruptionError(17, "crc check failed") 178 179 shortReadError := func(offset int, msg string) { 180 f := buf.Bytes() 181 tmp := make([]byte, len(f)) 182 copy(tmp, f) 183 s := deprecated.NewLegacyScanner(bytes.NewBuffer(tmp[:offset]), deprecated.LegacyScannerOpts{}) 184 if s.Scan() { 185 t.Errorf("expected false") 186 } 187 expect.HasSubstr(t, s.Err(), msg) 188 } 189 shortReadError(1, "unexpected EOF") 190 shortReadError(19, "unexpected EOF") 191 shortReadError(20, "short/long record") 192 193 readError := func(offset int, msg string) { 194 f := buf.Bytes() 195 tmp := make([]byte, len(f)) 196 copy(tmp, f) 197 rdr := &fakeReader{buf: bytes.NewBuffer(tmp), errAt: offset} 198 s := deprecated.NewLegacyScanner(rdr, deprecated.LegacyScannerOpts{}) 199 if s.Scan() { 200 t.Errorf("expected false") 201 } 202 expect.HasSubstr(t, s.Err(), msg) 203 } 204 readError(1, "failed to read header") 205 readError(9, "failed to read header") 206 readError(19, "failed to read header") 207 readError(20, "failed to read record") 208 209 defer func(oldSize uint64) { 210 internal.MaxReadRecordSize = oldSize 211 }(internal.MaxReadRecordSize) 212 internal.MaxReadRecordSize = 100 213 buf.Reset() 214 rw = deprecated.NewLegacyWriter(buf, deprecated.LegacyWriterOpts{}) 215 rw.Write([]byte(strings.Repeat("a", 101))) 216 s := deprecated.NewLegacyScanner(buf, deprecated.LegacyScannerOpts{}) 217 if got, want := s.Scan(), false; got != want { 218 t.Errorf("got %v, want %v", got, want) 219 } 220 expect.HasSubstr(t, s.Err(), "unreasonably large read") 221 222 ew := &fakeWriter{errAt: 25} 223 wr := deprecated.NewLegacyWriter(ew, deprecated.LegacyWriterOpts{}) 224 _, err := wr.WriteSlices([]byte("hello"), []byte(" "), []byte("world")) 225 expect.HasSubstr(t, err, "recordio: failed to write record") 226 } 227 228 func TestRecordioEmpty(t *testing.T) { 229 out := &bytes.Buffer{} 230 s := deprecated.NewLegacyScanner(out, deprecated.LegacyScannerOpts{}) 231 if got, want := s.Scan(), false; got != want { 232 t.Errorf("got %v, want %v", got, want) 233 } 234 if s.Bytes() != nil { 235 t.Errorf("expected nil slice") 236 } 237 if got, want := s.Scan(), false; got != want { 238 t.Errorf("got %v, want %v", got, want) 239 } 240 if err := s.Err(); err != nil { 241 t.Errorf("%v", err) 242 } 243 } 244 245 func TestRecordioMultiple(t *testing.T) { 246 out := &bytes.Buffer{} 247 rw := deprecated.NewLegacyWriter(out, deprecated.LegacyWriterOpts{}) 248 expected := []string{"", "hello", "world", "", "last record"} 249 for _, str := range expected { 250 n, err := rw.Write([]byte(str)) 251 assert.NoError(t, err) 252 assert.True(t, n == len(str)) 253 } 254 255 read := func(s deprecated.LegacyScanner) string { 256 if !s.Scan() { 257 expect.NoError(t, s.Err()) 258 return "eof" 259 } 260 return string(s.Bytes()) 261 } 262 s := deprecated.NewLegacyScanner(bytes.NewReader(out.Bytes()), deprecated.LegacyScannerOpts{}) 263 for _, str := range expected { 264 expect.EQ(t, read(s), str) 265 } 266 expect.EQ(t, "eof", read(s)) 267 268 s.Reset(bytes.NewReader(out.Bytes())) 269 for _, str := range expected[:2] { 270 expect.EQ(t, read(s), str) 271 } 272 s.Reset(bytes.NewReader(out.Bytes())) 273 for _, str := range expected { 274 expect.EQ(t, read(s), str) 275 } 276 expect.EQ(t, read(s), "eof") 277 } 278 279 type lazyReader struct { 280 rd io.Reader 281 } 282 283 func (lz *lazyReader) Read(p []byte) (n int, err error) { 284 lazy := len(p) 285 if lazy > 10 { 286 lazy -= 10 287 } 288 return io.ReadFull(lz.rd, p[:lazy]) 289 } 290 291 func TestRecordioWriteRead(t *testing.T) { 292 gosh := gosh.NewShell(t) 293 f := gosh.MakeTempFile() 294 rw := deprecated.NewLegacyWriter(f, deprecated.LegacyWriterOpts{}) 295 contents := []string{"hello", "world", "!"} 296 for _, rec := range contents { 297 rw.Write([]byte(rec)) 298 } 299 name := f.Name() 300 f.Close() 301 302 o := gosh.Cmd("gzip", "--keep", name).CombinedOutput() 303 if got, want := o, ""; got != want { 304 t.Errorf("got %v, want %v", got, want) 305 } 306 307 readall := func(rd io.Reader) ([][]byte, error) { 308 r := [][]byte{} 309 s := deprecated.NewLegacyScanner(rd, deprecated.LegacyScannerOpts{}) 310 for s.Scan() { 311 r = append(r, s.Bytes()) 312 } 313 return r, s.Err() 314 } 315 316 for _, n := range []string{name} { 317 rd, err := os.Open(n) 318 if err != nil { 319 t.Fatalf("%v: %v", n, err) 320 } 321 gz, err := os.Open(n + ".gz") 322 if err != nil { 323 t.Fatalf("%v: %v", n, err) 324 } 325 gzrd, err := gzip.NewReader(gz) 326 if err != nil { 327 t.Fatalf("%v: %v", n, err) 328 } 329 330 raw, err := readall(rd) 331 if err != nil { 332 t.Fatalf("%v (raw): read %d records, %v", n, len(raw), err) 333 } 334 335 compressed, err := readall(gzrd) 336 if err != nil { 337 t.Fatalf("%v (gzip): read %d records (%d raw), %v", n, len(compressed), len(raw), err) 338 } 339 340 buf, err := ioutil.ReadFile(n) 341 if err != nil { 342 t.Fatal(err) 343 } 344 lzr := &lazyReader{rd: bytes.NewBuffer(buf)} 345 lazy, err := readall(lzr) 346 if err != nil { 347 t.Fatalf("%v (lazyreader): read %d records (%d raw), %v", n, len(lazy), len(raw), err) 348 349 } 350 if got, want := len(raw), len(compressed); got != want { 351 t.Errorf("%v: got %v, want %v", n, got, want) 352 } 353 if got, want := raw, compressed; !reflect.DeepEqual(got, want) { 354 t.Errorf("%v: got %v, want %v", n, got, want) 355 } 356 if got, want := raw, lazy; !reflect.DeepEqual(got, want) { 357 t.Errorf("%v: got %v, want %v", n, got, want) 358 } 359 } 360 }