ariga.io/entcache@v0.1.1-0.20230620164151-0eb723a11c40/level.go (about) 1 package entcache 2 3 import ( 4 "bytes" 5 "context" 6 "database/sql/driver" 7 "encoding/gob" 8 "errors" 9 "fmt" 10 "sync" 11 "time" 12 13 "github.com/golang/groupcache/lru" 14 "github.com/redis/go-redis/v9" 15 ) 16 17 type ( 18 // Entry defines an entry to store in a cache. 19 Entry struct { 20 Columns []string 21 Values [][]driver.Value 22 } 23 24 // A Key defines a comparable Go value. 25 // See http://golang.org/ref/spec#Comparison_operators 26 Key any 27 28 // AddGetDeleter defines the interface for getting, 29 // adding and deleting entries from the cache. 30 AddGetDeleter interface { 31 Del(context.Context, Key) error 32 Add(context.Context, Key, *Entry, time.Duration) error 33 Get(context.Context, Key) (*Entry, error) 34 } 35 ) 36 37 func init() { 38 // Register non builtin driver.Values. 39 gob.Register(time.Time{}) 40 } 41 42 // MarshalBinary implements the encoding.BinaryMarshaler interface. 43 func (e Entry) MarshalBinary() ([]byte, error) { 44 entry := struct { 45 C []string 46 V [][]driver.Value 47 }{ 48 C: e.Columns, 49 V: e.Values, 50 } 51 var buf bytes.Buffer 52 if err := gob.NewEncoder(&buf).Encode(entry); err != nil { 53 return nil, err 54 } 55 return buf.Bytes(), nil 56 } 57 58 // UnmarshalBinary implements the encoding.BinaryUnmarshaler interface. 59 func (e *Entry) UnmarshalBinary(buf []byte) error { 60 var entry struct { 61 C []string 62 V [][]driver.Value 63 } 64 if err := gob.NewDecoder(bytes.NewBuffer(buf)).Decode(&entry); err != nil { 65 return err 66 } 67 e.Values = entry.V 68 e.Columns = entry.C 69 return nil 70 } 71 72 // ErrNotFound is returned by Get when and Entry does not exist in the cache. 73 var ErrNotFound = errors.New("entcache: entry was not found") 74 75 type ( 76 // LRU provides an LRU cache that implements the AddGetter interface. 77 LRU struct { 78 mu sync.Mutex 79 *lru.Cache 80 } 81 // entry wraps the Entry with additional expiry information. 82 entry struct { 83 *Entry 84 expiry time.Time 85 } 86 ) 87 88 // NewLRU creates a new Cache. 89 // If maxEntries is zero, the cache has no limit. 90 func NewLRU(maxEntries int) *LRU { 91 return &LRU{ 92 Cache: lru.New(maxEntries), 93 } 94 } 95 96 // Add adds the entry to the cache. 97 func (l *LRU) Add(_ context.Context, k Key, e *Entry, ttl time.Duration) error { 98 l.mu.Lock() 99 defer l.mu.Unlock() 100 buf, err := e.MarshalBinary() 101 if err != nil { 102 return err 103 } 104 ne := &Entry{} 105 if err := ne.UnmarshalBinary(buf); err != nil { 106 return err 107 } 108 if ttl == 0 { 109 l.Cache.Add(k, ne) 110 } else { 111 l.Cache.Add(k, &entry{Entry: ne, expiry: time.Now().Add(ttl)}) 112 } 113 return nil 114 } 115 116 // Get gets an entry from the cache. 117 func (l *LRU) Get(_ context.Context, k Key) (*Entry, error) { 118 l.mu.Lock() 119 e, ok := l.Cache.Get(k) 120 l.mu.Unlock() 121 if !ok { 122 return nil, ErrNotFound 123 } 124 switch e := e.(type) { 125 case *Entry: 126 return e, nil 127 case *entry: 128 if time.Now().Before(e.expiry) { 129 return e.Entry, nil 130 } 131 l.mu.Lock() 132 l.Cache.Remove(k) 133 l.mu.Unlock() 134 return nil, ErrNotFound 135 default: 136 return nil, fmt.Errorf("entcache: unexpected entry type: %T", e) 137 } 138 } 139 140 // Del deletes an entry from the cache. 141 func (l *LRU) Del(_ context.Context, k Key) error { 142 l.mu.Lock() 143 l.Cache.Remove(k) 144 l.mu.Unlock() 145 return nil 146 } 147 148 // Redis provides a remote cache backed by Redis 149 // and implements the SetGetter interface. 150 type Redis struct { 151 c redis.Cmdable 152 } 153 154 // NewRedis returns a new Redis cache level from the given Redis connection. 155 // 156 // entcache.NewRedis(redis.NewClient(&redis.Options{ 157 // Addr: ":6379" 158 // })) 159 // 160 // entcache.NewRedis(redis.NewClusterClient(&redis.ClusterOptions{ 161 // Addrs: []string{":7000", ":7001", ":7002"}, 162 // })) 163 func NewRedis(c redis.Cmdable) *Redis { 164 return &Redis{c: c} 165 } 166 167 // Add adds the entry to the cache. 168 func (r *Redis) Add(ctx context.Context, k Key, e *Entry, ttl time.Duration) error { 169 key := fmt.Sprint(k) 170 if key == "" { 171 return nil 172 } 173 buf, err := e.MarshalBinary() 174 if err != nil { 175 return err 176 } 177 if err := r.c.Set(ctx, key, buf, ttl).Err(); err != nil { 178 return err 179 } 180 return nil 181 } 182 183 // Get gets an entry from the cache. 184 func (r *Redis) Get(ctx context.Context, k Key) (*Entry, error) { 185 key := fmt.Sprint(k) 186 if key == "" { 187 return nil, ErrNotFound 188 } 189 buf, err := r.c.Get(ctx, key).Bytes() 190 if err != nil || len(buf) == 0 { 191 return nil, ErrNotFound 192 } 193 e := &Entry{} 194 if err := e.UnmarshalBinary(buf); err != nil { 195 return nil, err 196 } 197 return e, nil 198 } 199 200 // Del deletes an entry from the cache. 201 func (r *Redis) Del(ctx context.Context, k Key) error { 202 key := fmt.Sprint(k) 203 if key == "" { 204 return nil 205 } 206 return r.c.Del(ctx, key).Err() 207 } 208 209 // multiLevel provides a multi-level cache implementation. 210 type multiLevel struct { 211 levels []AddGetDeleter 212 } 213 214 // Add adds the entry to the cache. 215 func (m *multiLevel) Add(ctx context.Context, k Key, e *Entry, ttl time.Duration) error { 216 for i := range m.levels { 217 if err := m.levels[i].Add(ctx, k, e, ttl); err != nil { 218 return err 219 } 220 } 221 return nil 222 } 223 224 // Get gets an entry from the cache. 225 func (m *multiLevel) Get(ctx context.Context, k Key) (*Entry, error) { 226 for i := range m.levels { 227 switch e, err := m.levels[i].Get(ctx, k); { 228 case err == nil: 229 return e, nil 230 case err != ErrNotFound: 231 return nil, err 232 } 233 } 234 return nil, ErrNotFound 235 } 236 237 // Del deletes an entry from the cache. 238 func (m *multiLevel) Del(ctx context.Context, k Key) error { 239 for i := range m.levels { 240 if err := m.levels[i].Del(ctx, k); err != nil { 241 return err 242 } 243 } 244 return nil 245 } 246 247 // contextLevel provides a context/request level cache implementation. 248 type contextLevel struct{} 249 250 // Get gets an entry from the cache. 251 func (*contextLevel) Get(ctx context.Context, k Key) (*Entry, error) { 252 c, ok := FromContext(ctx) 253 if !ok { 254 return nil, ErrNotFound 255 } 256 return c.Get(ctx, k) 257 } 258 259 // Add adds the entry to the cache. 260 func (*contextLevel) Add(ctx context.Context, k Key, e *Entry, ttl time.Duration) error { 261 c, ok := FromContext(ctx) 262 if !ok { 263 return nil 264 } 265 return c.Add(ctx, k, e, ttl) 266 } 267 268 // Del deletes an entry from the cache. 269 func (*contextLevel) Del(ctx context.Context, k Key) error { 270 c, ok := FromContext(ctx) 271 if !ok { 272 return nil 273 } 274 return c.Del(ctx, k) 275 }