ariga.io/entcache@v0.1.1-0.20230620164151-0eb723a11c40/level.go (about)

     1  package entcache
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"database/sql/driver"
     7  	"encoding/gob"
     8  	"errors"
     9  	"fmt"
    10  	"sync"
    11  	"time"
    12  
    13  	"github.com/golang/groupcache/lru"
    14  	"github.com/redis/go-redis/v9"
    15  )
    16  
    17  type (
    18  	// Entry defines an entry to store in a cache.
    19  	Entry struct {
    20  		Columns []string
    21  		Values  [][]driver.Value
    22  	}
    23  
    24  	// A Key defines a comparable Go value.
    25  	// See http://golang.org/ref/spec#Comparison_operators
    26  	Key any
    27  
    28  	// AddGetDeleter defines the interface for getting,
    29  	// adding and deleting entries from the cache.
    30  	AddGetDeleter interface {
    31  		Del(context.Context, Key) error
    32  		Add(context.Context, Key, *Entry, time.Duration) error
    33  		Get(context.Context, Key) (*Entry, error)
    34  	}
    35  )
    36  
    37  func init() {
    38  	// Register non builtin driver.Values.
    39  	gob.Register(time.Time{})
    40  }
    41  
    42  // MarshalBinary implements the encoding.BinaryMarshaler interface.
    43  func (e Entry) MarshalBinary() ([]byte, error) {
    44  	entry := struct {
    45  		C []string
    46  		V [][]driver.Value
    47  	}{
    48  		C: e.Columns,
    49  		V: e.Values,
    50  	}
    51  	var buf bytes.Buffer
    52  	if err := gob.NewEncoder(&buf).Encode(entry); err != nil {
    53  		return nil, err
    54  	}
    55  	return buf.Bytes(), nil
    56  }
    57  
    58  // UnmarshalBinary implements the encoding.BinaryUnmarshaler interface.
    59  func (e *Entry) UnmarshalBinary(buf []byte) error {
    60  	var entry struct {
    61  		C []string
    62  		V [][]driver.Value
    63  	}
    64  	if err := gob.NewDecoder(bytes.NewBuffer(buf)).Decode(&entry); err != nil {
    65  		return err
    66  	}
    67  	e.Values = entry.V
    68  	e.Columns = entry.C
    69  	return nil
    70  }
    71  
    72  // ErrNotFound is returned by Get when and Entry does not exist in the cache.
    73  var ErrNotFound = errors.New("entcache: entry was not found")
    74  
    75  type (
    76  	// LRU provides an LRU cache that implements the AddGetter interface.
    77  	LRU struct {
    78  		mu sync.Mutex
    79  		*lru.Cache
    80  	}
    81  	// entry wraps the Entry with additional expiry information.
    82  	entry struct {
    83  		*Entry
    84  		expiry time.Time
    85  	}
    86  )
    87  
    88  // NewLRU creates a new Cache.
    89  // If maxEntries is zero, the cache has no limit.
    90  func NewLRU(maxEntries int) *LRU {
    91  	return &LRU{
    92  		Cache: lru.New(maxEntries),
    93  	}
    94  }
    95  
    96  // Add adds the entry to the cache.
    97  func (l *LRU) Add(_ context.Context, k Key, e *Entry, ttl time.Duration) error {
    98  	l.mu.Lock()
    99  	defer l.mu.Unlock()
   100  	buf, err := e.MarshalBinary()
   101  	if err != nil {
   102  		return err
   103  	}
   104  	ne := &Entry{}
   105  	if err := ne.UnmarshalBinary(buf); err != nil {
   106  		return err
   107  	}
   108  	if ttl == 0 {
   109  		l.Cache.Add(k, ne)
   110  	} else {
   111  		l.Cache.Add(k, &entry{Entry: ne, expiry: time.Now().Add(ttl)})
   112  	}
   113  	return nil
   114  }
   115  
   116  // Get gets an entry from the cache.
   117  func (l *LRU) Get(_ context.Context, k Key) (*Entry, error) {
   118  	l.mu.Lock()
   119  	e, ok := l.Cache.Get(k)
   120  	l.mu.Unlock()
   121  	if !ok {
   122  		return nil, ErrNotFound
   123  	}
   124  	switch e := e.(type) {
   125  	case *Entry:
   126  		return e, nil
   127  	case *entry:
   128  		if time.Now().Before(e.expiry) {
   129  			return e.Entry, nil
   130  		}
   131  		l.mu.Lock()
   132  		l.Cache.Remove(k)
   133  		l.mu.Unlock()
   134  		return nil, ErrNotFound
   135  	default:
   136  		return nil, fmt.Errorf("entcache: unexpected entry type: %T", e)
   137  	}
   138  }
   139  
   140  // Del deletes an entry from the cache.
   141  func (l *LRU) Del(_ context.Context, k Key) error {
   142  	l.mu.Lock()
   143  	l.Cache.Remove(k)
   144  	l.mu.Unlock()
   145  	return nil
   146  }
   147  
   148  // Redis provides a remote cache backed by Redis
   149  // and implements the SetGetter interface.
   150  type Redis struct {
   151  	c redis.Cmdable
   152  }
   153  
   154  // NewRedis returns a new Redis cache level from the given Redis connection.
   155  //
   156  //	entcache.NewRedis(redis.NewClient(&redis.Options{
   157  //		Addr: ":6379"
   158  //	}))
   159  //
   160  //	entcache.NewRedis(redis.NewClusterClient(&redis.ClusterOptions{
   161  //		Addrs: []string{":7000", ":7001", ":7002"},
   162  //	}))
   163  func NewRedis(c redis.Cmdable) *Redis {
   164  	return &Redis{c: c}
   165  }
   166  
   167  // Add adds the entry to the cache.
   168  func (r *Redis) Add(ctx context.Context, k Key, e *Entry, ttl time.Duration) error {
   169  	key := fmt.Sprint(k)
   170  	if key == "" {
   171  		return nil
   172  	}
   173  	buf, err := e.MarshalBinary()
   174  	if err != nil {
   175  		return err
   176  	}
   177  	if err := r.c.Set(ctx, key, buf, ttl).Err(); err != nil {
   178  		return err
   179  	}
   180  	return nil
   181  }
   182  
   183  // Get gets an entry from the cache.
   184  func (r *Redis) Get(ctx context.Context, k Key) (*Entry, error) {
   185  	key := fmt.Sprint(k)
   186  	if key == "" {
   187  		return nil, ErrNotFound
   188  	}
   189  	buf, err := r.c.Get(ctx, key).Bytes()
   190  	if err != nil || len(buf) == 0 {
   191  		return nil, ErrNotFound
   192  	}
   193  	e := &Entry{}
   194  	if err := e.UnmarshalBinary(buf); err != nil {
   195  		return nil, err
   196  	}
   197  	return e, nil
   198  }
   199  
   200  // Del deletes an entry from the cache.
   201  func (r *Redis) Del(ctx context.Context, k Key) error {
   202  	key := fmt.Sprint(k)
   203  	if key == "" {
   204  		return nil
   205  	}
   206  	return r.c.Del(ctx, key).Err()
   207  }
   208  
   209  // multiLevel provides a multi-level cache implementation.
   210  type multiLevel struct {
   211  	levels []AddGetDeleter
   212  }
   213  
   214  // Add adds the entry to the cache.
   215  func (m *multiLevel) Add(ctx context.Context, k Key, e *Entry, ttl time.Duration) error {
   216  	for i := range m.levels {
   217  		if err := m.levels[i].Add(ctx, k, e, ttl); err != nil {
   218  			return err
   219  		}
   220  	}
   221  	return nil
   222  }
   223  
   224  // Get gets an entry from the cache.
   225  func (m *multiLevel) Get(ctx context.Context, k Key) (*Entry, error) {
   226  	for i := range m.levels {
   227  		switch e, err := m.levels[i].Get(ctx, k); {
   228  		case err == nil:
   229  			return e, nil
   230  		case err != ErrNotFound:
   231  			return nil, err
   232  		}
   233  	}
   234  	return nil, ErrNotFound
   235  }
   236  
   237  // Del deletes an entry from the cache.
   238  func (m *multiLevel) Del(ctx context.Context, k Key) error {
   239  	for i := range m.levels {
   240  		if err := m.levels[i].Del(ctx, k); err != nil {
   241  			return err
   242  		}
   243  	}
   244  	return nil
   245  }
   246  
   247  // contextLevel provides a context/request level cache implementation.
   248  type contextLevel struct{}
   249  
   250  // Get gets an entry from the cache.
   251  func (*contextLevel) Get(ctx context.Context, k Key) (*Entry, error) {
   252  	c, ok := FromContext(ctx)
   253  	if !ok {
   254  		return nil, ErrNotFound
   255  	}
   256  	return c.Get(ctx, k)
   257  }
   258  
   259  // Add adds the entry to the cache.
   260  func (*contextLevel) Add(ctx context.Context, k Key, e *Entry, ttl time.Duration) error {
   261  	c, ok := FromContext(ctx)
   262  	if !ok {
   263  		return nil
   264  	}
   265  	return c.Add(ctx, k, e, ttl)
   266  }
   267  
   268  // Del deletes an entry from the cache.
   269  func (*contextLevel) Del(ctx context.Context, k Key) error {
   270  	c, ok := FromContext(ctx)
   271  	if !ok {
   272  		return nil
   273  	}
   274  	return c.Del(ctx, k)
   275  }