github.com/DeltaLaboratory/entcache@v0.1.1/level.go (about)

     1  package entcache
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"database/sql/driver"
     7  	"encoding/gob"
     8  	"errors"
     9  	"fmt"
    10  	"sync"
    11  	"time"
    12  
    13  	"github.com/golang/groupcache/lru"
    14  	"github.com/redis/rueidis"
    15  )
    16  
    17  type (
    18  	// A Key defines a comparable Go value.
    19  	// See http://golang.org/ref/spec#Comparison_operators
    20  	Key any
    21  
    22  	// AddGetDeleter defines the interface for getting,
    23  	// adding and deleting entries from the cache.
    24  	AddGetDeleter interface {
    25  		Del(context.Context, Key) error
    26  		Add(context.Context, Key, *Entry, time.Duration) error
    27  		Get(context.Context, Key) (*Entry, error)
    28  	}
    29  )
    30  
    31  type Entry struct {
    32  	Columns []string         `cbor:"0,keyasint" json:"c" bson:"c"`
    33  	Values  [][]driver.Value `cbor:"1,keyasint" json:"v" bson:"v"`
    34  }
    35  
    36  // MarshalBinary implements the encoding.BinaryMarshaler interface.
    37  func (e Entry) MarshalBinary() ([]byte, error) {
    38  	entry := struct {
    39  		C []string
    40  		V [][]driver.Value
    41  	}{
    42  		C: e.Columns,
    43  		V: e.Values,
    44  	}
    45  	var buf bytes.Buffer
    46  	if err := gob.NewEncoder(&buf).Encode(entry); err != nil {
    47  		return nil, err
    48  	}
    49  	return buf.Bytes(), nil
    50  }
    51  
    52  // UnmarshalBinary implements the encoding.BinaryUnmarshaler interface.
    53  func (e *Entry) UnmarshalBinary(buf []byte) error {
    54  	var entry struct {
    55  		C []string
    56  		V [][]driver.Value
    57  	}
    58  	if err := gob.NewDecoder(bytes.NewBuffer(buf)).Decode(&entry); err != nil {
    59  		return err
    60  	}
    61  	e.Values = entry.V
    62  	e.Columns = entry.C
    63  	return nil
    64  }
    65  
    66  // ErrNotFound returned by Get when and Entry does not exist in the cache.
    67  var ErrNotFound = errors.New("entcache: entry was not found")
    68  
    69  type (
    70  	// LRU provides an LRU cache that implements the AddGetter interface.
    71  	LRU struct {
    72  		mu sync.Mutex
    73  		*lru.Cache
    74  	}
    75  	// entry wraps the Entry with additional expiry information.
    76  	entry struct {
    77  		*Entry
    78  		expiry time.Time
    79  	}
    80  )
    81  
    82  // NewLRU creates a new Cache.
    83  // If maxEntries is zero, the cache has no limit.
    84  func NewLRU(maxEntries int) *LRU {
    85  	return &LRU{
    86  		Cache: lru.New(maxEntries),
    87  	}
    88  }
    89  
    90  // Add adds the entry to the cache.
    91  func (l *LRU) Add(_ context.Context, k Key, e *Entry, ttl time.Duration) error {
    92  	l.mu.Lock()
    93  	defer l.mu.Unlock()
    94  	buf, err := e.MarshalBinary()
    95  	if err != nil {
    96  		return err
    97  	}
    98  	ne := &Entry{}
    99  	if err := ne.UnmarshalBinary(buf); err != nil {
   100  		return err
   101  	}
   102  	if ttl == 0 {
   103  		l.Cache.Add(k, ne)
   104  	} else {
   105  		l.Cache.Add(k, &entry{Entry: ne, expiry: time.Now().Add(ttl)})
   106  	}
   107  	return nil
   108  }
   109  
   110  // Get gets an entry from the cache.
   111  func (l *LRU) Get(_ context.Context, k Key) (*Entry, error) {
   112  	l.mu.Lock()
   113  	e, ok := l.Cache.Get(k)
   114  	l.mu.Unlock()
   115  	if !ok {
   116  		return nil, ErrNotFound
   117  	}
   118  	switch e := e.(type) {
   119  	case *Entry:
   120  		return e, nil
   121  	case *entry:
   122  		if time.Now().Before(e.expiry) {
   123  			return e.Entry, nil
   124  		}
   125  		l.mu.Lock()
   126  		l.Cache.Remove(k)
   127  		l.mu.Unlock()
   128  		return nil, ErrNotFound
   129  	default:
   130  		return nil, fmt.Errorf("entcache: unexpected entry type: %T", e)
   131  	}
   132  }
   133  
   134  // Del deletes an entry from the cache.
   135  func (l *LRU) Del(_ context.Context, k Key) error {
   136  	l.mu.Lock()
   137  	l.Cache.Remove(k)
   138  	l.mu.Unlock()
   139  	return nil
   140  }
   141  
   142  // Redis provides a remote cache backed by Redis
   143  // and implements the SetGetter interface.
   144  type Redis struct {
   145  	c rueidis.Client
   146  }
   147  
   148  // NewRedis returns a new Redis cache level from the given Redis connection.
   149  func NewRedis(c rueidis.Client) *Redis {
   150  	return &Redis{c: c}
   151  }
   152  
   153  // Add adds the entry to the cache.
   154  func (r *Redis) Add(ctx context.Context, k Key, e *Entry, ttl time.Duration) error {
   155  	key := fmt.Sprint(k)
   156  	if key == "" {
   157  		return nil
   158  	}
   159  	buf, err := e.MarshalBinary()
   160  	if err != nil {
   161  		return err
   162  	}
   163  	return r.c.Do(ctx, r.c.B().Set().Key(key).Value(rueidis.BinaryString(buf)).Ex(ttl).Build()).Error()
   164  }
   165  
   166  // Get gets an entry from the cache.
   167  func (r *Redis) Get(ctx context.Context, k Key) (*Entry, error) {
   168  	key := fmt.Sprint(k)
   169  	if key == "" {
   170  		return nil, ErrNotFound
   171  	}
   172  	buf, err := r.c.Do(ctx, r.c.B().Get().Key(key).Build()).AsBytes()
   173  	if err != nil || len(buf) == 0 {
   174  		return nil, ErrNotFound
   175  	}
   176  	e := &Entry{}
   177  	if err := e.UnmarshalBinary(buf); err != nil {
   178  		return nil, err
   179  	}
   180  	return e, nil
   181  }
   182  
   183  // Del deletes an entry from the cache.
   184  func (r *Redis) Del(ctx context.Context, k Key) error {
   185  	key := fmt.Sprint(k)
   186  	if key == "" {
   187  		return nil
   188  	}
   189  	return r.c.Do(ctx, r.c.B().Del().Key(key).Build()).Error()
   190  }
   191  
   192  // multiLevel provides a multi-level cache implementation.
   193  type multiLevel struct {
   194  	levels []AddGetDeleter
   195  }
   196  
   197  // Add adds the entry to the cache.
   198  func (m *multiLevel) Add(ctx context.Context, k Key, e *Entry, ttl time.Duration) error {
   199  	for i := range m.levels {
   200  		if err := m.levels[i].Add(ctx, k, e, ttl); err != nil {
   201  			return err
   202  		}
   203  	}
   204  	return nil
   205  }
   206  
   207  // Get gets an entry from the cache.
   208  func (m *multiLevel) Get(ctx context.Context, k Key) (*Entry, error) {
   209  	for i := range m.levels {
   210  		switch e, err := m.levels[i].Get(ctx, k); {
   211  		case err == nil:
   212  			return e, nil
   213  		case !errors.Is(err, ErrNotFound):
   214  			return nil, err
   215  		}
   216  	}
   217  	return nil, ErrNotFound
   218  }
   219  
   220  // Del deletes an entry from the cache.
   221  func (m *multiLevel) Del(ctx context.Context, k Key) error {
   222  	for i := range m.levels {
   223  		if err := m.levels[i].Del(ctx, k); err != nil {
   224  			return err
   225  		}
   226  	}
   227  	return nil
   228  }
   229  
   230  // contextLevel provides a context/request level cache implementation.
   231  type contextLevel struct{}
   232  
   233  // Get gets an entry from the cache.
   234  func (*contextLevel) Get(ctx context.Context, k Key) (*Entry, error) {
   235  	c, ok := FromContext(ctx)
   236  	if !ok {
   237  		return nil, ErrNotFound
   238  	}
   239  	return c.Get(ctx, k)
   240  }
   241  
   242  // Add adds the entry to the cache.
   243  func (*contextLevel) Add(ctx context.Context, k Key, e *Entry, ttl time.Duration) error {
   244  	c, ok := FromContext(ctx)
   245  	if !ok {
   246  		return nil
   247  	}
   248  	return c.Add(ctx, k, e, ttl)
   249  }
   250  
   251  // Del deletes an entry from the cache.
   252  func (*contextLevel) Del(ctx context.Context, k Key) error {
   253  	c, ok := FromContext(ctx)
   254  	if !ok {
   255  		return nil
   256  	}
   257  	return c.Del(ctx, k)
   258  }