github.com/grailbio/bigslice@v0.0.0-20230519005545-30c4c12152ad/internal/slicecache/slicecache.go (about)

     1  package slicecache
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"runtime"
     7  
     8  	"github.com/grailbio/base/file"
     9  	"github.com/grailbio/base/traverse"
    10  	"github.com/grailbio/bigslice/sliceio"
    11  )
    12  
// Cacheable indicates a slice's data should be cached.
type Cacheable interface {
	// Cache returns the ShardCache through which the slice's cached shard
	// data is accessed.
	Cache() ShardCache
}
    17  
// ShardCache accesses cached data for a slice's shards.
type ShardCache interface {
	// IsCached reports whether the data of the given shard is cached.
	IsCached(shard int) bool
	// WritethroughReader takes a reader of the shard's computed data and
	// returns the reader callers should consume instead. Implementations may
	// use it to populate the cache (FileShardCache) or return reader
	// unchanged (Empty).
	WritethroughReader(shard int, reader sliceio.Reader) sliceio.Reader
	// CacheReader returns a reader of the shard's cached data. What happens
	// for a shard that is not cached is up to the implementation: Empty
	// panics, while FileShardCache returns a reader that always errors.
	CacheReader(shard int) sliceio.Reader
}
    24  
    25  // Empty is an empty cache.
    26  var Empty ShardCache = empty{}
    27  
    28  type empty struct{}
    29  
    30  func (empty) IsCached(shard int) bool { return false }
    31  func (empty) WritethroughReader(shard int, reader sliceio.Reader) sliceio.Reader {
    32  	return reader
    33  }
    34  func (empty) CacheReader(shard int) sliceio.Reader { panic("always empty") }
    35  
// FileShardCache is a ShardCache backed by files. A nil *FileShardCache has no
// cached data.
type FileShardCache struct {
	prefix        string // path prefix shared by every shard's cache file
	numShards     int    // total shard count; also encoded in each file path
	shardIsCached []bool // shardIsCached[i] is true if shard i is treated as cached
	requireAll    bool   // set by RequireAllCached; selects the path shown in CacheReader's error
}
    44  
const (
	// pathFormat is the format used for the path of cache files. Its
	// arguments are the cache prefix, the shard index, and the total number
	// of shards, e.g. "prefix-0003-of-0016".
	pathFormat = "%s-%04d-of-%04d"
	// pathFormatAllShards is the format used to refer to all of a cache's
	// files, for human consumption. Its arguments are the cache prefix and
	// the total number of shards; the shard index appears as the literal
	// placeholder "NNNN".
	pathFormatAllShards = "%s-NNNN-of-%04d"
)
    52  
    53  // NewShardCache constructs a ShardCache. It does O(numShards) parallelized
    54  // file operations to look up what's present in the cache.
    55  func NewFileShardCache(ctx context.Context, prefix string, numShards int) *FileShardCache {
    56  	if prefix == "" {
    57  		return &FileShardCache{}
    58  	}
    59  	// TODO(jcharumilind): Make this initialization more lazy. This is generally
    60  	// called within Funcs, but its result is generally ignored on workers to
    61  	// ensure a consistent view of the cache for consistent compilation.
    62  	c := FileShardCache{prefix, numShards, make([]bool, numShards), false}
    63  	_ = traverse.Limit(10*runtime.NumCPU()).Each(numShards, func(shard int) error {
    64  		_, err := file.Stat(ctx, c.path(shard))
    65  		c.shardIsCached[shard] = err == nil // treat lookup errors as cache misses
    66  		return nil
    67  	})
    68  	return &c
    69  }
    70  
    71  func (c *FileShardCache) path(shard int) string {
    72  	return fmt.Sprintf(pathFormat, c.prefix, shard, c.numShards)
    73  }
    74  
    75  func (c *FileShardCache) IsCached(shard int) bool {
    76  	if c == nil {
    77  		return false
    78  	}
    79  	return c.shardIsCached[shard]
    80  }
    81  
    82  func (c *FileShardCache) RequireAllCached() {
    83  	if c == nil {
    84  		return
    85  	}
    86  	c.requireAll = true
    87  	for _, b := range c.shardIsCached {
    88  		if !b {
    89  			for i := range c.shardIsCached {
    90  				c.shardIsCached[i] = false
    91  			}
    92  			return
    93  		}
    94  	}
    95  }
    96  
    97  // WritethroughReader returns a reader that populates the cache. reader should
    98  // read computed data.
    99  func (c *FileShardCache) WritethroughReader(shard int, reader sliceio.Reader) sliceio.Reader {
   100  	if c == nil {
   101  		return reader
   102  	}
   103  	return newWritethroughReader(reader, c.path(shard))
   104  }
   105  
   106  // CacheReader returns a reader that reads from the cache. If the shard is not
   107  // cached, returns a reader that will always return an error.
   108  func (c *FileShardCache) CacheReader(shard int) sliceio.Reader {
   109  	if !c.shardIsCached[shard] {
   110  		path := c.path(shard)
   111  		if c.requireAll {
   112  			path = fmt.Sprintf(pathFormatAllShards, c.prefix, c.numShards)
   113  		}
   114  		err := fmt.Errorf("cache %q invalid for shard %d(%d); check %q",
   115  			c.prefix, shard, c.numShards, path)
   116  		return sliceio.ErrReader(err)
   117  	}
   118  	return newFileReader(c.path(shard))
   119  }