ariga.io/entcache@v0.1.1-0.20230620164151-0eb723a11c40/driver.go (about) 1 package entcache 2 3 import ( 4 "context" 5 stdsql "database/sql" 6 "database/sql/driver" 7 "errors" 8 "fmt" 9 "strings" 10 "sync/atomic" 11 "time" 12 _ "unsafe" 13 14 "entgo.io/ent/dialect" 15 "entgo.io/ent/dialect/sql" 16 "github.com/mitchellh/hashstructure/v2" 17 ) 18 19 type ( 20 // Options wraps the basic configuration cache options. 21 Options struct { 22 // TTL defines the period of time that an Entry 23 // is valid in the cache. 24 TTL time.Duration 25 26 // Cache defines the GetAddDeleter (cache implementation) 27 // for holding the cache entries. If no cache implementation 28 // was provided, an LRU cache with no limit is used. 29 Cache AddGetDeleter 30 31 // Hash defines an optional Hash function for converting 32 // a query and its arguments to a cache key. If no Hash 33 // function was provided, the DefaultHash is used. 34 Hash func(query string, args []any) (Key, error) 35 36 // Logf function. If provided, the Driver will call it with 37 // errors that can not be handled. 38 Log func(...any) 39 } 40 41 // Option allows configuring the cache 42 // driver using functional options. 43 Option func(*Options) 44 45 // A Driver is an SQL cached client. Users should use the 46 // constructor below for creating new driver. 47 Driver struct { 48 dialect.Driver 49 *Options 50 stats Stats 51 } 52 ) 53 54 // NewDriver returns a new Driver an existing driver and optional 55 // configuration functions. 
For example: 56 // 57 // entcache.NewDriver( 58 // drv, 59 // entcache.TTL(time.Minute), 60 // entcache.Levels( 61 // NewLRU(256), 62 // NewRedis(redis.NewClient(&redis.Options{ 63 // Addr: ":6379", 64 // })), 65 // ) 66 // ) 67 func NewDriver(drv dialect.Driver, opts ...Option) *Driver { 68 options := &Options{Hash: DefaultHash, Cache: NewLRU(0)} 69 for _, opt := range opts { 70 opt(options) 71 } 72 return &Driver{ 73 Driver: drv, 74 Options: options, 75 } 76 } 77 78 // TTL configures the period of time that an Entry 79 // is valid in the cache. 80 func TTL(ttl time.Duration) Option { 81 return func(o *Options) { 82 o.TTL = ttl 83 } 84 } 85 86 // Hash configures an optional Hash function for 87 // converting a query and its arguments to a cache key. 88 func Hash(hash func(query string, args []any) (Key, error)) Option { 89 return func(o *Options) { 90 o.Hash = hash 91 } 92 } 93 94 // Levels configures the Driver to work with the given cache levels. 95 // For example, in process LRU cache and a remote Redis cache. 96 func Levels(levels ...AddGetDeleter) Option { 97 return func(o *Options) { 98 if len(levels) == 1 { 99 o.Cache = levels[0] 100 } else { 101 o.Cache = &multiLevel{levels: levels} 102 } 103 } 104 } 105 106 // ContextLevel configures the driver to work with context/request level cache. 107 // Users that use this option, should wraps the *http.Request context with the 108 // cache value as follows: 109 // 110 // ctx = entcache.NewContext(ctx) 111 // 112 // ctx = entcache.NewContext(ctx, entcache.NewLRU(128)) 113 func ContextLevel() Option { 114 return func(o *Options) { 115 o.Cache = &contextLevel{} 116 } 117 } 118 119 // Query implements the Querier interface for the driver. It falls back to the 120 // underlying wrapped driver in case of caching error. 121 // 122 // Note that, the driver does not synchronize identical queries that are executed 123 // concurrently. 
Hence, if 2 identical queries are executed at the ~same time, and 124 // there is no cache entry for them, the driver will execute both of them and the 125 // last successful one will be stored in the cache. 126 func (d *Driver) Query(ctx context.Context, query string, args, v any) error { 127 // Check if the given statement looks like a standard Ent query (e.g. SELECT). 128 // Custom queries (e.g. CTE) or statements that are prefixed with comments are 129 // not supported. This check is mainly necessary, because PostgreSQL and SQLite 130 // may execute insert statement like "INSERT ... RETURNING" using Driver.Query. 131 if !strings.HasPrefix(query, "SELECT") && !strings.HasPrefix(query, "select") { 132 return d.Driver.Query(ctx, query, args, v) 133 } 134 vr, ok := v.(*sql.Rows) 135 if !ok { 136 return fmt.Errorf("entcache: invalid type %T. expect *sql.Rows", v) 137 } 138 argv, ok := args.([]any) 139 if !ok { 140 return fmt.Errorf("entcache: invalid type %T. expect []interface{} for args", args) 141 } 142 opts, err := d.optionsFromContext(ctx, query, argv) 143 if err != nil { 144 return d.Driver.Query(ctx, query, args, v) 145 } 146 atomic.AddUint64(&d.stats.Gets, 1) 147 switch e, err := d.Cache.Get(ctx, opts.key); { 148 case err == nil: 149 atomic.AddUint64(&d.stats.Hits, 1) 150 vr.ColumnScanner = &repeater{columns: e.Columns, values: e.Values} 151 case err == ErrNotFound: 152 if err := d.Driver.Query(ctx, query, args, vr); err != nil { 153 return err 154 } 155 vr.ColumnScanner = &recorder{ 156 ColumnScanner: vr.ColumnScanner, 157 onClose: func(columns []string, values [][]driver.Value) { 158 err := d.Cache.Add(ctx, opts.key, &Entry{Columns: columns, Values: values}, opts.ttl) 159 if err != nil && d.Log != nil { 160 atomic.AddUint64(&d.stats.Errors, 1) 161 d.Log(fmt.Sprintf("entcache: failed storing entry %v in cache: %v", opts.key, err)) 162 } 163 }, 164 } 165 default: 166 return d.Driver.Query(ctx, query, args, v) 167 } 168 return nil 169 } 170 171 // Stats 
returns a copy of the cache statistics. 172 func (d *Driver) Stats() Stats { 173 return Stats{ 174 Gets: atomic.LoadUint64(&d.stats.Gets), 175 Hits: atomic.LoadUint64(&d.stats.Hits), 176 Errors: atomic.LoadUint64(&d.stats.Errors), 177 } 178 } 179 180 // QueryContext calls QueryContext of the underlying driver, or fails if it is not supported. 181 // Note, this method is not part of the caching layer since Ent does not use it by default. 182 func (d *Driver) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) { 183 drv, ok := d.Driver.(interface { 184 QueryContext(context.Context, string, ...any) (*sql.Rows, error) 185 }) 186 if !ok { 187 return nil, fmt.Errorf("Driver.QueryContext is not supported") 188 } 189 return drv.QueryContext(ctx, query, args...) 190 } 191 192 // ExecContext calls ExecContext of the underlying driver, or fails if it is not supported. 193 func (d *Driver) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { 194 drv, ok := d.Driver.(interface { 195 ExecContext(context.Context, string, ...any) (sql.Result, error) 196 }) 197 if !ok { 198 return nil, fmt.Errorf("Driver.ExecContext is not supported") 199 } 200 return drv.ExecContext(ctx, query, args...) 201 } 202 203 // errSkip tells the driver to skip cache layer. 204 var errSkip = errors.New("entcache: skip cache") 205 206 // optionsFromContext returns the injected options from the context, or its default value. 
207 func (d *Driver) optionsFromContext(ctx context.Context, query string, args []any) (ctxOptions, error) { 208 var opts ctxOptions 209 if c, ok := ctx.Value(ctxOptionsKey).(*ctxOptions); ok { 210 opts = *c 211 } 212 if opts.key == nil { 213 key, err := d.Hash(query, args) 214 if err != nil { 215 return opts, errSkip 216 } 217 opts.key = key 218 } 219 if opts.ttl == 0 { 220 opts.ttl = d.TTL 221 } 222 if opts.evict { 223 if err := d.Cache.Del(ctx, opts.key); err != nil { 224 return opts, err 225 } 226 } 227 if opts.skip { 228 return opts, errSkip 229 } 230 return opts, nil 231 } 232 233 // DefaultHash provides the default implementation for converting 234 // a query and its argument to a cache key. 235 func DefaultHash(query string, args []any) (Key, error) { 236 key, err := hashstructure.Hash(struct { 237 Q string 238 A []any 239 }{ 240 Q: query, 241 A: args, 242 }, hashstructure.FormatV2, nil) 243 if err != nil { 244 return nil, err 245 } 246 return key, nil 247 } 248 249 // Stats represents the cache statistics of the driver. 250 type Stats struct { 251 Gets uint64 252 Hits uint64 253 Errors uint64 254 } 255 256 // rawCopy copies the driver values by implementing 257 // the sql.Scanner interface. 258 type rawCopy struct { 259 values []driver.Value 260 } 261 262 func (c *rawCopy) Scan(src interface{}) error { 263 if b, ok := src.([]byte); ok { 264 b1 := make([]byte, len(b)) 265 copy(b1, b) 266 src = b1 267 } 268 c.values[0] = src 269 c.values = c.values[1:] 270 return nil 271 } 272 273 // recorder represents an sql.Rows recorder that implements 274 // the entgo.io/ent/dialect/sql.ColumnScanner interface. 
275 type recorder struct { 276 sql.ColumnScanner 277 values [][]driver.Value 278 columns []string 279 done bool 280 onClose func([]string, [][]driver.Value) 281 } 282 283 // Next wraps the underlying Next method 284 func (r *recorder) Next() bool { 285 hasNext := r.ColumnScanner.Next() 286 r.done = !hasNext 287 return hasNext 288 } 289 290 // Scan copies database values for future use (by the repeater) 291 // and assign them to the given destinations using the standard 292 // database/sql.convertAssign function. 293 func (r *recorder) Scan(dest ...any) error { 294 values := make([]driver.Value, len(dest)) 295 args := make([]any, len(dest)) 296 c := &rawCopy{values: values} 297 for i := range args { 298 args[i] = c 299 } 300 if err := r.ColumnScanner.Scan(args...); err != nil { 301 return err 302 } 303 for i := range values { 304 if err := convertAssign(dest[i], values[i]); err != nil { 305 return err 306 } 307 } 308 r.values = append(r.values, values) 309 return nil 310 } 311 312 // Columns wraps the underlying Column method and stores it in the recorder state. 313 // The repeater.Columns cannot be called if the recorder method was not called before. 314 // That means, raw scanning should be identical for identical queries. 315 func (r *recorder) Columns() ([]string, error) { 316 columns, err := r.ColumnScanner.Columns() 317 if err != nil { 318 return nil, err 319 } 320 r.columns = columns 321 return columns, nil 322 } 323 324 func (r *recorder) Close() error { 325 if err := r.ColumnScanner.Close(); err != nil { 326 return err 327 } 328 // If we did not encounter any error during iteration, 329 // and we scanned all rows, we store it on cache. 330 if err := r.ColumnScanner.Err(); err == nil || r.done { 331 r.onClose(r.columns, r.values) 332 } 333 return nil 334 } 335 336 // repeater repeats columns scanning from cache history. 
type repeater struct {
	// columns holds the column names of the cached result set.
	columns []string
	// values holds the cached rows that were not scanned yet.
	values [][]driver.Value
}

// Close is a no-op; the repeater holds no resources.
func (*repeater) Close() error {
	return nil
}

// ColumnTypes is not supported by the repeater and always returns an error.
func (*repeater) ColumnTypes() ([]*stdsql.ColumnType, error) {
	return nil, fmt.Errorf("entcache.ColumnTypes is not supported")
}

// Columns returns the cached column names.
func (r *repeater) Columns() ([]string, error) {
	return r.columns, nil
}

// Err always returns nil; replaying cached rows cannot fail.
func (*repeater) Err() error {
	return nil
}

// Next reports whether there is a cached row left to scan.
// It does not advance the cursor; rows are consumed by Scan.
func (r *repeater) Next() bool {
	return len(r.values) > 0
}

// NextResultSet reports whether there are cached rows left.
// NOTE(review): the repeater holds a single result set, so returning
// false unconditionally may be the intended behavior — confirm.
func (r *repeater) NextResultSet() bool {
	return len(r.values) > 0
}

// Scan assigns the values of the next cached row to dest using the standard
// database/sql.convertAssign function, and then advances to the following row.
// It returns database/sql.ErrNoRows if no rows are left, and an error if the
// number of destinations does not match the number of cached values. The row
// is not consumed when Scan fails.
func (r *repeater) Scan(dest ...any) error {
	if !r.Next() {
		return stdsql.ErrNoRows
	}
	row := r.values[0]
	// Guard against a count mismatch: too few destinations would panic below,
	// and too many would be silently left unassigned.
	if len(dest) != len(row) {
		return fmt.Errorf("entcache: expected %d destination arguments in Scan, not %d", len(row), len(dest))
	}
	for i, src := range row {
		if err := convertAssign(dest[i], src); err != nil {
			return err
		}
	}
	r.values = r.values[1:]
	return nil
}

//go:linkname convertAssign database/sql.convertAssign
func convertAssign(dest, src any) error