github.com/DeltaLaboratory/entcache@v0.1.1/driver.go (about) 1 package entcache 2 3 import ( 4 "context" 5 stdsql "database/sql" 6 "database/sql/driver" 7 "errors" 8 "fmt" 9 "strings" 10 "sync/atomic" 11 "time" 12 _ "unsafe" 13 14 "entgo.io/ent/dialect" 15 "entgo.io/ent/dialect/sql" 16 "github.com/mitchellh/hashstructure/v2" 17 ) 18 19 type ( 20 // Options wrap the basic configuration cache options. 21 Options struct { 22 // TTL defines the period of time that an Entry 23 // is valid in the cache. 24 TTL time.Duration 25 26 // Cache defines the GetAddDeleter (cache implementation) 27 // for holding the cache entries. If no cache implementation 28 // was provided, an LRU cache with no limit is used. 29 Cache AddGetDeleter 30 31 // Hash defines an optional Hash function for converting 32 // a query and its arguments to a cache key. If no Hash 33 // function was provided, the DefaultHash is used. 34 Hash func(query string, args []any) (Key, error) 35 36 // Logf function. If provided, the Driver will call it with 37 // errors that cannot be handled. 38 Log func(...any) 39 } 40 41 // Option allows configuring the cache 42 // driver using functional options. 43 Option func(*Options) 44 45 // A Driver is an SQL-cached client. 46 // Users should use the constructor below for creating a new driver. 47 Driver struct { 48 dialect.Driver 49 *Options 50 stats Stats 51 } 52 ) 53 54 // NewDriver returns a new Driver an existing driver and optional 55 // configuration functions. 
56 // For example, 57 // 58 // entcache.NewDriver( 59 // drv, 60 // entcache.TTL(time.Minute), 61 // entcache.Levels( 62 // NewLRU(256), 63 // NewRedis(redis.NewClient(&redis.Options{ 64 // Addr: ":6379", 65 // })), 66 // ) 67 // ) 68 func NewDriver(drv dialect.Driver, opts ...Option) *Driver { 69 options := &Options{Hash: DefaultHash, Cache: NewLRU(0)} 70 for _, opt := range opts { 71 opt(options) 72 } 73 return &Driver{ 74 Driver: drv, 75 Options: options, 76 } 77 } 78 79 // TTL configures the period of time that an Entry 80 // is valid in the cache. 81 func TTL(ttl time.Duration) Option { 82 return func(o *Options) { 83 o.TTL = ttl 84 } 85 } 86 87 // Hash configures an optional Hash function for 88 // converting a query and its arguments to a cache key. 89 func Hash(hash func(query string, args []any) (Key, error)) Option { 90 return func(o *Options) { 91 o.Hash = hash 92 } 93 } 94 95 // Levels configure the Driver to work with the given cache levels. 96 // For example, in process LRU cache and a remote Redis cache. 97 func Levels(levels ...AddGetDeleter) Option { 98 return func(o *Options) { 99 if len(levels) == 1 { 100 o.Cache = levels[0] 101 } else { 102 o.Cache = &multiLevel{levels: levels} 103 } 104 } 105 } 106 107 // ContextLevel configures the driver to work with context/request level cache. 108 // Users that use this option should wrap the *http.Request context with the 109 // cache value as follows: 110 // 111 // ctx = entcache.NewContext(ctx) 112 // 113 // ctx = entcache.NewContext(ctx, entcache.NewLRU(128)) 114 func ContextLevel() Option { 115 return func(o *Options) { 116 o.Cache = &contextLevel{} 117 } 118 } 119 120 // Query implements the Querier interface for the driver. It falls back to the 121 // underlying wrapped driver in case of caching error. 122 // 123 // Note that the driver does not synchronize identical queries that are executed 124 // concurrently. 
Hence, if two identical queries are executed at the ~same time, and 125 // there is no cache entry for them, the driver will execute both of them and the 126 // last successful one will be stored in the cache. 127 func (d *Driver) Query(ctx context.Context, query string, args, v any) error { 128 // Check if the given statement looks like a standard Ent query (e.g., SELECT). 129 // Custom queries (e.g., CTE) or statements that are prefixed with comments are not supported. 130 // This check is mainly necessary because PostgreSQL and SQLite 131 // may execute an insert statement like "INSERT ... RETURNING" using Driver.Query. 132 if !strings.HasPrefix(query, "SELECT") && !strings.HasPrefix(query, "select") { 133 return d.Driver.Query(ctx, query, args, v) 134 } 135 vr, ok := v.(*sql.Rows) 136 if !ok { 137 return fmt.Errorf("entcache: invalid type %T. expect *sql.Rows", v) 138 } 139 argv, ok := args.([]any) 140 if !ok { 141 return fmt.Errorf("entcache: invalid type %T. expect []any for args", args) 142 } 143 opts, err := d.optionsFromContext(ctx, query, argv) 144 if err != nil { 145 return d.Driver.Query(ctx, query, args, v) 146 } 147 atomic.AddUint64(&d.stats.Gets, 1) 148 switch e, err := d.Cache.Get(ctx, opts.key); { 149 case err == nil: 150 atomic.AddUint64(&d.stats.Hits, 1) 151 vr.ColumnScanner = &repeater{columns: e.Columns, values: e.Values} 152 case errors.Is(err, ErrNotFound): 153 if err := d.Driver.Query(ctx, query, args, vr); err != nil { 154 return err 155 } 156 vr.ColumnScanner = &recorder{ 157 ColumnScanner: vr.ColumnScanner, 158 onClose: func(columns []string, values [][]driver.Value) { 159 err := d.Cache.Add(ctx, opts.key, &Entry{Columns: columns, Values: values}, opts.ttl) 160 if err != nil && d.Log != nil { 161 atomic.AddUint64(&d.stats.Errors, 1) 162 d.Log(fmt.Sprintf("entcache: failed storing entry %v in cache: %v", opts.key, err)) 163 } 164 }, 165 } 166 default: 167 return d.Driver.Query(ctx, query, args, v) 168 } 169 return nil 170 } 171 172 // 
Stats return a copy of the cache statistics. 173 func (d *Driver) Stats() Stats { 174 return Stats{ 175 Gets: atomic.LoadUint64(&d.stats.Gets), 176 Hits: atomic.LoadUint64(&d.stats.Hits), 177 Errors: atomic.LoadUint64(&d.stats.Errors), 178 } 179 } 180 181 // QueryContext calls QueryContext of the underlying driver, or fails if it is not supported. 182 // Note, this method is not part of the caching layer since Ent does not use it by default. 183 func (d *Driver) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) { 184 drv, ok := d.Driver.(interface { 185 QueryContext(context.Context, string, ...any) (*sql.Rows, error) 186 }) 187 if !ok { 188 return nil, fmt.Errorf("Driver.QueryContext is not supported") 189 } 190 return drv.QueryContext(ctx, query, args...) 191 } 192 193 // ExecContext calls ExecContext of the underlying driver, or fails if it is not supported. 194 func (d *Driver) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { 195 drv, ok := d.Driver.(interface { 196 ExecContext(context.Context, string, ...any) (sql.Result, error) 197 }) 198 if !ok { 199 return nil, fmt.Errorf("Driver.ExecContext is not supported") 200 } 201 return drv.ExecContext(ctx, query, args...) 202 } 203 204 // errSkip tells the driver to skip cache layer. 205 var errSkip = errors.New("entcache: skip cache") 206 207 // optionsFromContext returns the injected options from the context, or its default value. 
208 func (d *Driver) optionsFromContext(ctx context.Context, query string, args []any) (ctxOptions, error) { 209 var opts ctxOptions 210 if c, ok := ctx.Value(ctxOptionsKey).(*ctxOptions); ok { 211 opts = *c 212 } 213 if opts.key == nil { 214 key, err := d.Hash(query, args) 215 if err != nil { 216 return opts, errSkip 217 } 218 opts.key = key 219 } 220 if opts.ttl == 0 { 221 opts.ttl = d.TTL 222 } 223 if opts.evict { 224 if err := d.Cache.Del(ctx, opts.key); err != nil { 225 return opts, err 226 } 227 } 228 if opts.skip { 229 return opts, errSkip 230 } 231 return opts, nil 232 } 233 234 // DefaultHash provides the default implementation for converting 235 // a query and its argument to a cache key. 236 func DefaultHash(query string, args []any) (Key, error) { 237 key, err := hashstructure.Hash(struct { 238 Q string 239 A []any 240 }{ 241 Q: query, 242 A: args, 243 }, hashstructure.FormatV2, nil) 244 if err != nil { 245 return nil, err 246 } 247 return key, nil 248 } 249 250 // Stats represent the cache statistics of the driver. 251 type Stats struct { 252 Gets uint64 253 Hits uint64 254 Errors uint64 255 } 256 257 // rawCopy copies the driver values by implementing 258 // the sql.Scanner interface. 259 type rawCopy struct { 260 values []driver.Value 261 } 262 263 func (c *rawCopy) Scan(src interface{}) error { 264 if b, ok := src.([]byte); ok { 265 b1 := make([]byte, len(b)) 266 copy(b1, b) 267 src = b1 268 } 269 c.values[0] = src 270 c.values = c.values[1:] 271 return nil 272 } 273 274 // recorder represents a sql.Rows recorder that implements 275 // the entgo.io/ent/dialect/sql.ColumnScanner interface. 
276 type recorder struct { 277 sql.ColumnScanner 278 values [][]driver.Value 279 columns []string 280 done bool 281 onClose func([]string, [][]driver.Value) 282 } 283 284 // Next wraps the underlying Next method 285 func (r *recorder) Next() bool { 286 hasNext := r.ColumnScanner.Next() 287 r.done = !hasNext 288 return hasNext 289 } 290 291 // Scan copies database values for future use (by the repeater) 292 // and assign them to the given destinations using the standard 293 // database/sql.convertAssign function. 294 func (r *recorder) Scan(dest ...any) error { 295 values := make([]driver.Value, len(dest)) 296 args := make([]any, len(dest)) 297 c := &rawCopy{values: values} 298 for i := range args { 299 args[i] = c 300 } 301 if err := r.ColumnScanner.Scan(args...); err != nil { 302 return err 303 } 304 for i := range values { 305 if err := convertAssign(dest[i], values[i]); err != nil { 306 return err 307 } 308 } 309 r.values = append(r.values, values) 310 return nil 311 } 312 313 // Columns wrap the underlying Column method and store it in the recorder state. 314 // The repeater.Columns cannot be called if the recorder method was not called before. 315 // That means raw scanning should be identical for identical queries. 316 func (r *recorder) Columns() ([]string, error) { 317 columns, err := r.ColumnScanner.Columns() 318 if err != nil { 319 return nil, err 320 } 321 r.columns = columns 322 return columns, nil 323 } 324 325 func (r *recorder) Close() error { 326 if err := r.ColumnScanner.Close(); err != nil { 327 return err 328 } 329 // If we did not encounter any error during iteration, 330 // and we scanned all rows, we store it on cache. 331 if err := r.ColumnScanner.Err(); err == nil || r.done { 332 r.onClose(r.columns, r.values) 333 } 334 return nil 335 } 336 337 // repeater repeats columns scanning from cache history. 
type repeater struct {
	columns []string
	values  [][]driver.Value
}

// Close is a no-op; the replayed rows live in memory.
func (*repeater) Close() error {
	return nil
}

// ColumnTypes is not supported for cached results.
func (*repeater) ColumnTypes() ([]*stdsql.ColumnType, error) {
	return nil, fmt.Errorf("entcache.ColumnTypes is not supported")
}

// Columns returns the column names stored with the cached entry.
func (r *repeater) Columns() ([]string, error) {
	return r.columns, nil
}

// Err always returns nil; replaying from memory cannot fail mid-iteration.
func (*repeater) Err() error {
	return nil
}

// Next reports whether rows remain. Note that it does not advance:
// the row at the head is consumed by Scan.
func (r *repeater) Next() bool {
	return len(r.values) > 0
}

// NextResultSet reports whether another result set is available. A cached
// entry stores a single column list and therefore a single result set, so it
// always returns false. Remaining rows are discarded, mirroring database/sql,
// which closes the rows when there is no further result set.
func (r *repeater) NextResultSet() bool {
	r.values = nil
	return false
}

// Scan assigns the row at the head to dest using database/sql's conversion
// rules, then consumes it. It returns stdsql.ErrNoRows when no rows remain,
// and an error (leaving the row in place) when the number of destinations
// does not match the number of cached values.
func (r *repeater) Scan(dest ...any) error {
	if !r.Next() {
		return stdsql.ErrNoRows
	}
	row := r.values[0]
	// Guard the indexing below: a mismatch previously panicked with an
	// index out of range instead of reporting an error like database/sql.
	if len(dest) != len(row) {
		return fmt.Errorf("entcache: expected %d destination arguments in Scan, not %d", len(row), len(dest))
	}
	for i, src := range row {
		if err := convertAssign(dest[i], src); err != nil {
			return err
		}
	}
	r.values = r.values[1:]
	return nil
}

// convertAssign is database/sql's internal value-conversion routine, linked
// in so cached values are assigned exactly as database/sql.Rows.Scan would.
//
//go:linkname convertAssign database/sql.convertAssign
func convertAssign(dest, src any) error