github.com/grailbio/base@v0.0.11/recordio/scannerv2.go (about) 1 // Copyright 2018 GRAIL, Inc. All rights reserved. 2 // Use of this source code is governed by the Apache-2.0 3 // license that can be found in the LICENSE file. 4 5 package recordio 6 7 import ( 8 "encoding/binary" 9 "fmt" 10 "io" 11 "sync" 12 13 "github.com/grailbio/base/errors" 14 "github.com/grailbio/base/recordio/internal" 15 ) 16 17 var scannerFreePool = sync.Pool{ 18 New: func() interface{} { 19 return &scannerv2{} 20 }, 21 } 22 23 // rawItemList is the result of uncompressing & parsing one recordio block. 24 type rawItemList struct { 25 bytes []byte // raw bytes, post transformation. 26 firstOff int // bytes[firstOff:] contain the application payload 27 cumSize []int // cumSize[x] is the cumulative bytesize of items [0,x]. 28 } 29 30 func (ri *rawItemList) clear() { 31 ri.bytes = ri.bytes[:0] 32 ri.cumSize = ri.cumSize[:0] 33 ri.firstOff = 0 34 } 35 36 // len returns the number of items in the block. 37 func (ri *rawItemList) len() int { return len(ri.cumSize) } 38 39 // item returns the i'th (base 0) item. 40 // 41 // REQUIRES: 0 <= i < ri.len(). 42 func (ri *rawItemList) item(i int) []byte { 43 startOff := ri.firstOff 44 if i > 0 { 45 startOff += ri.cumSize[i-1] 46 } 47 limitOff := ri.firstOff + ri.cumSize[i] 48 return ri.bytes[startOff:limitOff] 49 } 50 51 // Given block contents, apply transformation if any, and parse it into a list 52 // of items. If transform is nil, it defaults to identity. 53 func parseChunksToItems(rawItems *rawItemList, chunks [][]byte, transform TransformFunc) error { 54 if transform == nil { 55 // TODO(saito) Allow TransformFunc to return an iov, and refactor the rest 56 // of the codebase to consume it. 57 transform = idTransform 58 } 59 var err error 60 if rawItems.bytes != nil { 61 // zstd doesn't like an empty slice (zstd.go:100) 62 // 63 // TODO(saito) fix upstream. 64 rawItems.bytes = rawItems.bytes[:cap(rawItems.bytes)] 65 } 66 if rawItems.bytes, err = transform(rawItems.bytes, chunks); err != nil { 67 return err 68 } 69 block := rawItems.bytes 70 unItems, n := binary.Uvarint(block) 71 if n <= 0 { 72 return fmt.Errorf("recordio: failed to read number of packed items: %v", n) 73 } 74 nItems := int(unItems) 75 pos := n 76 77 if cap(rawItems.cumSize) < nItems { 78 rawItems.cumSize = make([]int, nItems) 79 } else { 80 rawItems.cumSize = rawItems.cumSize[:nItems] 81 } 82 total := 0 83 for i := 0; i < nItems; i++ { 84 size, n := binary.Uvarint(block[pos:]) 85 if n <= 0 { 86 return fmt.Errorf("recordio: likely corrupt data, failed to read size of packed item %v: %v", i, n) 87 } 88 total += int(size) 89 rawItems.cumSize[i] = total 90 pos += n 91 } 92 rawItems.firstOff = pos 93 if total+pos != len(block) { 94 return fmt.Errorf("recordio: corrupt block header, got block size %d, expected %d", len(block), total+pos) 95 } 96 return nil 97 } 98 99 // ScannerOpts defines options used when creating a new scanner. 100 type ScannerOpts struct { 101 // LegacyTransform is used only to read the legacy recordio files. For the V2 102 // recordio files, this field is ignored, and transformers are constructed 103 // from the header metadata. 104 LegacyTransform TransformFunc 105 106 // Unmarshal transforms a byte slice into an application object. It is called 107 // for every item read from storage. If nil, a function that returns []byte 108 // unchanged is used. The return value from Unmarshal can be retrieved using 109 // the Scanner.Get method. 110 Unmarshal func(in []byte) (out interface{}, err error) 111 } 112 113 // Scanner defines an interface for recordio scanner. 114 // 115 // A Scanner implementation must be thread safe. Legal path expression is 116 // defined below. Err, Header, and Trailer can be called at any time. 117 // 118 // ((Scan Get*) | Seek)* Finish 119 // 120 type Scanner interface { 121 // Header returns the contents of the header block. 122 Header() ParsedHeader 123 124 // Scan returns true if a new record was read, false otherwise. It will return 125 // false on encountering an error; the error may be retrieved using the Err 126 // method. Note, that Scan will reuse storage from one invocation to the next. 127 Scan() bool 128 129 // Get returns the current item as read by a prior call to Scan. 130 // 131 // REQUIRES: Preceding Scan calls have returned true. There is no Seek 132 // call between the last Scan call and the Get call. 133 Get() interface{} 134 135 // Err returns any error encountered by the writer. Once Err() becomes 136 // non-nil, it stays so. 137 Err() error 138 139 // Set up so that the next Scan() call causes the pointer to move to the given 140 // location. On any error, Err() will be set. 141 // 142 // REQUIRES: loc must be one of the values passed to the Index callback 143 // during writes. 144 Seek(loc ItemLocation) 145 146 // Trailer returns the trailer block contents. If the trailer does not exist, 147 // or is corrupt, it returns nil. The caller should examine Err() if Trailer 148 // returns nil. 149 Trailer() []byte 150 151 // Return the file format version. Not for general use. 152 Version() FormatVersion 153 154 // Finish should be called exactly once, after the application has finished 155 // using the scanner. It returns the value of Err(). 156 // 157 // The Finish method recycles the internal scanner resources for use by other 158 // scanners, thereby reducing GC overhead. THe application must not touch the 159 // scanner object after Finish. 160 Finish() error 161 } 162 163 type scannerv2 struct { 164 err errors.Once 165 sc *internal.ChunkScanner 166 opts ScannerOpts 167 untransform TransformFunc 168 header ParsedHeader 169 170 rawItems rawItemList 171 item interface{} 172 nextItem int 173 } 174 175 func idUnmarshal(data []byte) (interface{}, error) { 176 return data, nil 177 } 178 179 type errorScanner struct { 180 err error 181 } 182 183 func (s *errorScanner) Header() (p ParsedHeader) { return } 184 func (s *errorScanner) Trailer() (b []byte) { return } 185 func (s *errorScanner) Version() (v FormatVersion) { return } 186 func (s *errorScanner) Get() interface{} { panic(fmt.Sprintf("errorscannerv2.Get: %v", s.err)) } 187 func (s *errorScanner) Scan() bool { return false } 188 func (s *errorScanner) Seek(ItemLocation) {} 189 func (s *errorScanner) Finish() error { return s.Err() } 190 func (s *errorScanner) Err() error { 191 if s.err == io.EOF { 192 return nil 193 } 194 return s.err 195 } 196 197 // NewScanner creates a new recordio scanner. The reader can read both legacy 198 // recordio files (packed or unpacked) or the new-format files. Any error is 199 // reported through the Scanner.Err method. 200 func NewScanner(in io.ReadSeeker, opts ScannerOpts) Scanner { 201 return NewShardScanner(in, opts, 0, 1, 1) 202 } 203 204 // NewShardScanner creates a new sharded recordio scanner. The returned scanner 205 // reads shard [start,limit) (of [0,nshard)) of the recordio file at the 206 // ReadSeeker in. Sharding is only supported for v2 recordio files; an error 207 // scanner is returned if NewShardScanner is called for a legacy recordio file. 208 // 209 // NewShardScanner with shard and nshard set to 0 and 1 respectively (i.e., 210 // a single shard) behaves as NewScanner. 211 func NewShardScanner(in io.ReadSeeker, opts ScannerOpts, start, limit, nshard int) Scanner { 212 if opts.Unmarshal == nil { 213 opts.Unmarshal = idUnmarshal 214 } 215 if err := internal.Seek(in, 0); err != nil { 216 return &errorScanner{err} 217 } 218 var magic internal.MagicBytes 219 if _, err := io.ReadFull(in, magic[:]); err != nil { 220 return &errorScanner{err} 221 } 222 if err := internal.Seek(in, 0); err != nil { 223 return &errorScanner{err} 224 } 225 if start >= limit || limit > nshard || start < 0 || nshard <= 0 { 226 return &errorScanner{fmt.Errorf("invalid sharding [%d,%d) of %d", start, limit, nshard)} 227 } 228 if magic != internal.MagicHeader { 229 if start != 0 || limit != 1 || nshard != 1 { 230 return &errorScanner{errors.New("legacy record IOs do not support sharding")} 231 } 232 return newLegacyScannerAdapter(in, opts) 233 } 234 return newScanner(in, start, limit, nshard, opts) 235 } 236 237 func newScanner(in io.ReadSeeker, start, limit, nshard int, opts ScannerOpts) Scanner { 238 s := scannerFreePool.Get().(*scannerv2) 239 if s == nil { 240 panic("newScannerV2") 241 } 242 s.err = errors.Once{Ignored: []error{io.EOF}} 243 s.opts = opts 244 s.untransform = nil 245 s.header = nil 246 s.nextItem = 0 247 s.item = nil 248 s.sc = internal.NewChunkScanner(in, &s.err) 249 s.rawItems.clear() 250 s.readHeader() 251 if s.Err() != nil { 252 return s 253 } 254 // Technically, we shouldn't be reading the trailer again, but 255 // the block scanner just ignores it anyway. 256 s.sc.LimitShard(start, limit, nshard) 257 return s 258 } 259 260 func (s *scannerv2) readSpecialBlock(expectedMagic internal.MagicBytes, tr TransformFunc) []byte { 261 if !s.sc.Scan() { 262 s.err.Set(fmt.Errorf("Failed to read block %v", expectedMagic)) 263 return nil 264 } 265 magic, chunks := s.sc.Block() 266 if magic != expectedMagic { 267 s.err.Set(fmt.Errorf("Failed to read block, expect %v, got %v", expectedMagic, magic)) 268 return nil 269 } 270 rawItems := rawItemList{} 271 err := parseChunksToItems(&rawItems, chunks, tr) 272 if err != nil { 273 s.err.Set(err) 274 return nil 275 } 276 if rawItems.len() != 1 { 277 s.err.Set(fmt.Errorf("Wrong # of items in header block, %d", rawItems.len())) 278 return nil 279 } 280 return rawItems.item(0) 281 } 282 283 func (s *scannerv2) readHeader() { 284 payload := s.readSpecialBlock(internal.MagicHeader, idTransform) 285 if s.err.Err() != nil { 286 return 287 } 288 if err := s.header.unmarshal(payload); err != nil { 289 s.err.Set(err) 290 return 291 } 292 transformers := []string{} 293 for _, h := range s.header { 294 if h.Key == KeyTransformer { 295 str, ok := h.Value.(string) 296 if !ok { 297 s.err.Set(fmt.Errorf("Expect string value for key %v, but found %v", h.Key, h.Value)) 298 return 299 } 300 transformers = append(transformers, str) 301 } 302 } 303 var err error 304 s.untransform, err = registry.GetUntransformer(transformers) 305 s.err.Set(err) 306 } 307 308 func (s *scannerv2) Version() FormatVersion { 309 return V2 310 } 311 312 func (s *scannerv2) Header() ParsedHeader { 313 return s.header 314 } 315 316 func (s *scannerv2) Trailer() []byte { 317 if !s.header.HasTrailer() { 318 return nil 319 } 320 curOff := s.sc.Tell() 321 defer s.sc.Seek(curOff) 322 323 magic, chunks := s.sc.ReadLastBlock() 324 if s.err.Err() != nil { 325 return nil 326 } 327 if magic != internal.MagicTrailer { 328 s.err.Set(fmt.Errorf("Did not found the trailer, instead found magic %v", magic)) 329 return nil 330 } 331 rawItems := rawItemList{} 332 err := parseChunksToItems(&rawItems, chunks, s.untransform) 333 if err != nil { 334 s.err.Set(err) 335 return nil 336 } 337 if rawItems.len() != 1 { 338 s.err.Set(fmt.Errorf("Expect exactly one trailer item, but found %d", rawItems.len())) 339 return nil 340 } 341 return rawItems.item(0) 342 } 343 344 func (s *scannerv2) Get() interface{} { 345 return s.item 346 } 347 348 func (s *scannerv2) Seek(loc ItemLocation) { 349 // TODO(saito) Avoid seeking the file if loc.Block points to the current block. 350 if s.err.Err() == io.EOF { 351 s.err = errors.Once{} 352 } 353 s.sc.Seek(int64(loc.Block)) 354 if !s.scanNextBlock() { 355 return 356 } 357 if loc.Item >= s.rawItems.len() { 358 s.err.Set(fmt.Errorf("Invalid location %+v, block has only %d items", loc, s.rawItems.len())) 359 } 360 s.nextItem = loc.Item 361 } 362 363 func (s *scannerv2) scanNextBlock() bool { 364 s.rawItems.clear() 365 s.nextItem = 0 366 if s.Err() != nil { 367 return false 368 } 369 // Need to read the next record. 370 if !s.sc.Scan() { 371 return false 372 } 373 magic, chunks := s.sc.Block() 374 if magic == internal.MagicPacked { 375 if err := parseChunksToItems(&s.rawItems, chunks, s.untransform); err != nil { 376 s.err.Set(err) 377 return false 378 } 379 s.nextItem = 0 380 return true 381 } 382 if magic == internal.MagicTrailer { 383 // EOF 384 return false 385 } 386 s.err.Set(fmt.Errorf("recordio: invalid magic number: %v", magic)) 387 return false 388 } 389 390 func (s *scannerv2) Scan() bool { 391 for s.nextItem >= s.rawItems.len() { 392 if !s.scanNextBlock() { 393 return false 394 } 395 } 396 item, err := s.opts.Unmarshal(s.rawItems.item(s.nextItem)) 397 if err != nil { 398 s.err.Set(err) 399 return false 400 } 401 s.item = item 402 s.nextItem++ 403 return true 404 } 405 406 func (s *scannerv2) Err() error { 407 err := s.err.Err() 408 if err == io.EOF { 409 err = nil 410 } 411 return err 412 } 413 414 func (s *scannerv2) Finish() error { 415 err := s.Err() 416 s.err = errors.Once{} 417 s.opts = ScannerOpts{} 418 s.sc = nil 419 s.untransform = nil 420 s.header = nil 421 s.nextItem = 0 422 s.item = nil 423 scannerFreePool.Put(s) 424 return err 425 }