github.com/woocoos/entcache@v0.0.0-20231206055445-856f0148efa5/driver.go

package entcache

import (
	"context"
	stdsql "database/sql"
	"database/sql/driver"
	"entgo.io/ent/dialect"
	"entgo.io/ent/dialect/sql"
	"errors"
	"fmt"
	"github.com/tsingsun/woocoo/pkg/cache"
	"github.com/tsingsun/woocoo/pkg/cache/lfu"
	"github.com/tsingsun/woocoo/pkg/conf"
	"github.com/tsingsun/woocoo/pkg/log"
	"strings"
	"sync/atomic"
	"time"
	_ "unsafe"
)

//go:linkname convertAssign database/sql.convertAssign
func convertAssign(dest, src any) error

const (
	defaultDriverName = "default"
	defaultGCInterval = time.Hour
)

var (
	// errSkip tells the driver to skip cache layer.
	errSkip = errors.New("entcache: skip cache")

	driverManager = make(map[string]*Driver)
	logger        = log.Component("entcache")
)

type (
	// A Driver is a SQL cached client. Users should use the
	// constructor below for creating a new driver.
	Driver struct {
		*Config
		dialect.Driver
		stats Stats

		Hash func(query string, args []any) (Key, error)
	}
	// Stats represent the cache statistics of the driver.
	Stats struct {
		Gets   uint64
		Hits   uint64
		Errors uint64
	}
)

// NewDriver wraps the given driver with a caching layer.
func NewDriver(drv dialect.Driver, opts ...Option) *Driver {
	options := &Config{
		Name:        defaultDriverName,
		GCInterval:  defaultGCInterval,
		KeyQueryTTL: defaultGCInterval,
	}
	for _, opt := range opts {
		opt(options)
	}
	var d *Driver
	d, ok := driverManager[options.Name]
	if !ok {
		d = &Driver{}
		driverManager[options.Name] = d
	}
	d.Config = options
	if d.Config.Cache == nil {
		if d.Config.StoreKey != "" {
			var err error
			d.Cache, err = cache.GetCache(d.Config.StoreKey)
			if err != nil {
				panic(err)
			}
		} else {
			cnf := conf.NewFromStringMap(map[string]any{
				"size": 10000,
			})
			if d.Config.HashQueryTTL > 0 {
				cnf.Parser().Set("ttl", d.Config.HashQueryTTL)
			}
			c, err := lfu.NewTinyLFU(cnf)
			if err != nil {
				panic(err)
			}
			d.Cache = c
		}
	}
	d.Driver = drv
	d.Hash = DefaultHash
	if d.ChangeSet == nil {
		d.ChangeSet = NewChangeSet(d.GCInterval)
	}
	return d
}
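
// Usage sketch (illustrative only, not part of this file): wrapping an ent SQL
// driver with the caching layer. The dialect, DSN, and the generated ent client
// below are assumptions for demonstration; functional options are omitted here
// because their constructors are defined elsewhere in this package.
//
//	drv, err := sql.Open(dialect.MySQL, "user:pass@tcp(localhost:3306)/app?parseTime=true")
//	if err != nil {
//		panic(err)
//	}
//	cached := entcache.NewDriver(drv)
//	client := ent.NewClient(ent.Driver(cached)) // hypothetical generated client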

// Query implements the Querier interface for the driver. It falls back to the
// underlying wrapped driver in case of caching error.
//
// Note that the driver does not synchronize identical queries that are executed
// concurrently. Hence, if 2 identical queries are executed at the ~same time, and
// there is no cache entry for them, the driver will execute both of them and the
// last successful one will be stored in the cache.
func (d *Driver) Query(ctx context.Context, query string, args, v any) error {
	// Check if the given statement looks like a standard Ent query (e.g. SELECT).
	// Custom queries (e.g. CTE) or statements that are prefixed with comments are
	// not supported. This check is mainly necessary, because PostgreSQL and SQLite
	// may execute an insert statement like "INSERT ... RETURNING" using Driver.Query.
	if !strings.HasPrefix(query, "SELECT") && !strings.HasPrefix(query, "select") {
		return d.Driver.Query(ctx, query, args, v)
	}
	vr, ok := v.(*sql.Rows)
	if !ok {
		return fmt.Errorf("entcache: invalid type %T. expect *sql.Rows", v)
	}
	argv, ok := args.([]any)
	if !ok {
		return fmt.Errorf("entcache: invalid type %T. expect []interface{} for args", args)
	}
	opts, err := d.optionsFromContext(ctx, query, argv)
	if err != nil {
		return d.Driver.Query(ctx, query, args, v)
	}
	atomic.AddUint64(&d.stats.Gets, 1)
	var e Entry
	if opts.evict {
		err = cache.ErrCacheMiss
	} else {
		err = d.Cache.Get(ctx, string(opts.key), &e, cache.WithSkip(opts.skipMode))
	}
	switch {
	case err == nil:
		atomic.AddUint64(&d.stats.Hits, 1)
		vr.ColumnScanner = &repeater{columns: e.Columns, values: e.Values}
	case errors.Is(err, cache.ErrCacheMiss):
		if err := d.Driver.Query(ctx, query, args, vr); err != nil {
			return err
		}
		vr.ColumnScanner = &recorder{
			ColumnScanner: vr.ColumnScanner,
			onClose: func(columns []string, values [][]driver.Value) {
				err := d.Cache.Set(ctx, string(opts.key), &Entry{Columns: columns, Values: values},
					cache.WithTTL(opts.ttl), cache.WithSkip(opts.skipMode),
				)
				if err != nil {
					atomic.AddUint64(&d.stats.Errors, 1)
					logger.Warn(fmt.Sprintf("entcache: failed storing entry %v in cache: %v", opts.key, err))
				}
			},
		}
	default:
		return d.Driver.Query(ctx, query, args, v)
	}
	return nil
}

// optionsFromContext returns the options injected into the context, or their default values.
// Note that the key injected into the context is an entry key; it is replaced below by the
// hashed query key, which improves the cache hit rate.
func (d *Driver) optionsFromContext(ctx context.Context, query string, args []any) (ctxOptions, error) {
	var opts ctxOptions
	if c, ok := ctx.Value(ctxOptionsKey).(*ctxOptions); ok {
		opts = *c
		if c.key != "" {
			c.key = "" // clear it for eager loading.
		}
	}
	key, err := d.Hash(query, args)
	if err != nil {
		return opts, errSkip
	}
	switch {
	case opts.ref && opts.key != "":
		if t, ok := d.ChangeSet.Load(opts.key); ok {
			rt, loaded := d.ChangeSet.LoadOrStoreRef(key)
			// Evict on the first query within the entity-changed period, or when a
			// new entity change happened after the previous query.
			opts.evict = !loaded || t.After(rt)
		} else if _, ok := d.ChangeSet.LoadRef(key); ok {
			opts.evict = true
			d.ChangeSet.DeleteRef(key)
		}
		if opts.ttl == 0 {
			opts.ttl = d.KeyQueryTTL
		}
	case opts.key == "":
		if opts.ttl == 0 {
			opts.ttl = d.HashQueryTTL
		}
	case opts.key != "":
		if _, ok := d.ChangeSet.Load(opts.key); ok {
			opts.evict = true
			d.ChangeSet.Delete(opts.key)
		}
		if opts.ttl == 0 {
			opts.ttl = d.KeyQueryTTL
		}
	}
	// use the hashed key as the cache key
	opts.key = key
	if d.CachePrefix != "" {
		opts.key = Key(d.CachePrefix) + opts.key
	}
	if opts.skipMode == cache.SkipCache {
		return opts, errSkip
	}
	return opts, nil
}
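
// The types below implement the record-and-replay mechanism used by Query above:
// on a cache miss a recorder wraps the real ColumnScanner, captures the raw
// driver values through rawCopy while rows are scanned, and hands them to
// onClose as an Entry; on a cache hit a repeater feeds the stored values back
// through convertAssign. As an illustration (sample data, not real output),
// after scanning two rows of (id, name) a recorder holds roughly:
//
//	columns: []string{"id", "name"}
//	values:  [][]driver.Value{{int64(1), []byte("a")}, {int64(2), []byte("b")}}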

// rawCopy copies the driver values by implementing
// the sql.Scanner interface.
type rawCopy struct {
	values []driver.Value
}

func (c *rawCopy) Scan(src interface{}) error {
	if b, ok := src.([]byte); ok {
		b1 := make([]byte, len(b))
		copy(b1, b)
		src = b1
	}
	c.values[0] = src
	c.values = c.values[1:]
	return nil
}

// recorder represents an sql.Rows recorder that implements
// the entgo.io/ent/dialect/sql.ColumnScanner interface.
type recorder struct {
	sql.ColumnScanner
	values  [][]driver.Value
	columns []string
	done    bool
	onClose func([]string, [][]driver.Value)
}

// Next wraps the underlying Next method.
func (r *recorder) Next() bool {
	hasNext := r.ColumnScanner.Next()
	r.done = !hasNext
	return hasNext
}

// Scan copies database values for future use (by the repeater)
// and assigns them to the given destinations using the standard
// database/sql.convertAssign function.
func (r *recorder) Scan(dest ...any) error {
	values := make([]driver.Value, len(dest))
	args := make([]any, len(dest))
	c := &rawCopy{values: values}
	for i := range args {
		args[i] = c
	}
	if err := r.ColumnScanner.Scan(args...); err != nil {
		return err
	}
	for i := range values {
		if err := convertAssign(dest[i], values[i]); err != nil {
			return err
		}
	}
	r.values = append(r.values, values)
	return nil
}

// Columns wraps the underlying Columns method and stores the result in the recorder state.
// repeater.Columns cannot be called if the recorder's Columns method was not called beforehand.
// That means raw scanning should be identical for identical queries.
func (r *recorder) Columns() ([]string, error) {
	columns, err := r.ColumnScanner.Columns()
	if err != nil {
		return nil, err
	}
	r.columns = columns
	return columns, nil
}

func (r *recorder) Close() error {
	if err := r.ColumnScanner.Close(); err != nil {
		return err
	}
	// If we did not encounter any error during iteration,
	// or we scanned all rows, we store the result in the cache.
	if err := r.ColumnScanner.Err(); err == nil || r.done {
		r.onClose(r.columns, r.values)
	}
	return nil
}

// repeater repeats columns scanning from cache history.
type repeater struct {
	columns []string
	values  [][]driver.Value
}

func (*repeater) Close() error {
	return nil
}
func (*repeater) ColumnTypes() ([]*stdsql.ColumnType, error) {
	return nil, fmt.Errorf("entcache.ColumnTypes is not supported")
}
func (r *repeater) Columns() ([]string, error) {
	return r.columns, nil
}
func (*repeater) Err() error {
	return nil
}
func (r *repeater) Next() bool {
	return len(r.values) > 0
}

func (r *repeater) NextResultSet() bool {
	return len(r.values) > 0
}

func (r *repeater) Scan(dest ...any) error {
	if !r.Next() {
		return stdsql.ErrNoRows
	}
	for i, src := range r.values[0] {
		if err := convertAssign(dest[i], src); err != nil {
			return err
		}
	}
	r.values = r.values[1:]
	return nil
}
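
// Minimal replay sketch (illustrative, using made-up sample data): this is
// roughly what Query does on a cache hit, swapping the rows' ColumnScanner for
// a repeater over a cached Entry before ent scans the result.
//
//	e := &Entry{
//		Columns: []string{"id", "name"},
//		Values:  [][]driver.Value{{int64(1), []byte("alice")}},
//	}
//	rows := &sql.Rows{ColumnScanner: &repeater{columns: e.Columns, values: e.Values}}
//	var (
//		id   int
//		name string
//	)
//	for rows.Next() {
//		if err := rows.Scan(&id, &name); err != nil {
//			panic(err)
//		}
//	}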