github.com/woocoos/entcache@v0.0.0-20231206055445-856f0148efa5/driver.go (about)

     1  package entcache
     2  
     3  import (
     4  	"context"
     5  	stdsql "database/sql"
     6  	"database/sql/driver"
     7  	"entgo.io/ent/dialect"
     8  	"entgo.io/ent/dialect/sql"
     9  	"errors"
    10  	"fmt"
    11  	"github.com/tsingsun/woocoo/pkg/cache"
    12  	"github.com/tsingsun/woocoo/pkg/cache/lfu"
    13  	"github.com/tsingsun/woocoo/pkg/conf"
    14  	"github.com/tsingsun/woocoo/pkg/log"
    15  	"strings"
    16  	"sync/atomic"
    17  	"time"
    18  	_ "unsafe"
    19  )
    20  
    21  //go:linkname convertAssign database/sql.convertAssign
    22  func convertAssign(dest, src any) error
    23  
    24  const (
    25  	defaultDriverName = "default"
    26  	defaultGCInterval = time.Hour
    27  )
    28  
    29  var (
    30  	// errSkip tells the driver to skip cache layer.
    31  	errSkip = errors.New("entcache: skip cache")
    32  
    33  	driverManager = make(map[string]*Driver)
    34  	logger        = log.Component("entcache")
    35  )
    36  
    37  type (
    38  	// A Driver is a SQL cached client. Users should use the
    39  	// constructor below for creating a new driver.
    40  	Driver struct {
    41  		*Config
    42  		dialect.Driver
    43  		stats Stats
    44  
    45  		Hash func(query string, args []any) (Key, error)
    46  	}
    47  	// Stats represent the cache statistics of the driver.
    48  	Stats struct {
    49  		Gets   uint64
    50  		Hits   uint64
    51  		Errors uint64
    52  	}
    53  )
    54  
    55  // NewDriver wraps the given driver with a caching layer.
    56  func NewDriver(drv dialect.Driver, opts ...Option) *Driver {
    57  	options := &Config{
    58  		Name:        defaultDriverName,
    59  		GCInterval:  defaultGCInterval,
    60  		KeyQueryTTL: defaultGCInterval,
    61  	}
    62  	for _, opt := range opts {
    63  		opt(options)
    64  	}
    65  	var d *Driver
    66  	d, ok := driverManager[options.Name]
    67  	if !ok {
    68  		d = &Driver{}
    69  		driverManager[options.Name] = d
    70  	}
    71  	d.Config = options
    72  	if d.Config.Cache == nil {
    73  		if d.Config.StoreKey != "" {
    74  			var err error
    75  			d.Cache, err = cache.GetCache(d.Config.StoreKey)
    76  			if err != nil {
    77  				panic(err)
    78  			}
    79  		} else {
    80  			cnf := conf.NewFromStringMap(map[string]any{
    81  				"size": 10000,
    82  			})
    83  			if d.Config.HashQueryTTL > 0 {
    84  				cnf.Parser().Set("ttl", d.Config.HashQueryTTL)
    85  			}
    86  			c, err := lfu.NewTinyLFU(cnf)
    87  			if err != nil {
    88  				panic(err)
    89  			}
    90  			d.Cache = c
    91  		}
    92  	}
    93  	d.Driver = drv
    94  	d.Hash = DefaultHash
    95  	if d.ChangeSet == nil {
    96  		d.ChangeSet = NewChangeSet(d.GCInterval)
    97  	}
    98  	return d
    99  }
   100  
   101  // Query implements the Querier interface for the driver. It falls back to the
   102  // underlying wrapped driver in case of caching error.
   103  //
   104  // Note that the driver does not synchronize identical queries that are executed
   105  // concurrently. Hence, if 2 identical queries are executed at the ~same time, and
   106  // there is no cache entry for them, the driver will execute both of them and the
   107  // last successful one will be stored in the cache.
   108  func (d *Driver) Query(ctx context.Context, query string, args, v any) error {
   109  	// Check if the given statement looks like a standard Ent query (e.g. SELECT).
   110  	// Custom queries (e.g. CTE) or statements that are prefixed with comments are
   111  	// not supported. This check is mainly necessary, because PostgreSQL and SQLite
   112  	// may execute an insert statement like "INSERT ... RETURNING" using Driver.Query.
   113  	if !strings.HasPrefix(query, "SELECT") && !strings.HasPrefix(query, "select") {
   114  		return d.Driver.Query(ctx, query, args, v)
   115  	}
   116  	vr, ok := v.(*sql.Rows)
   117  	if !ok {
   118  		return fmt.Errorf("entcache: invalid type %T. expect *sql.Rows", v)
   119  	}
   120  	argv, ok := args.([]any)
   121  	if !ok {
   122  		return fmt.Errorf("entcache: invalid type %T. expect []interface{} for args", args)
   123  	}
   124  	opts, err := d.optionsFromContext(ctx, query, argv)
   125  	if err != nil {
   126  		return d.Driver.Query(ctx, query, args, v)
   127  	}
   128  	atomic.AddUint64(&d.stats.Gets, 1)
   129  	var e Entry
   130  	if opts.evict {
   131  		err = cache.ErrCacheMiss
   132  	} else {
   133  		err = d.Cache.Get(ctx, string(opts.key), &e, cache.WithSkip(opts.skipMode))
   134  	}
   135  	switch {
   136  	case err == nil:
   137  		atomic.AddUint64(&d.stats.Hits, 1)
   138  		vr.ColumnScanner = &repeater{columns: e.Columns, values: e.Values}
   139  	case errors.Is(err, cache.ErrCacheMiss):
   140  		if err := d.Driver.Query(ctx, query, args, vr); err != nil {
   141  			return err
   142  		}
   143  		vr.ColumnScanner = &recorder{
   144  			ColumnScanner: vr.ColumnScanner,
   145  			onClose: func(columns []string, values [][]driver.Value) {
   146  				err := d.Cache.Set(ctx, string(opts.key), &Entry{Columns: columns, Values: values},
   147  					cache.WithTTL(opts.ttl), cache.WithSkip(opts.skipMode),
   148  				)
   149  				if err != nil {
   150  					atomic.AddUint64(&d.stats.Errors, 1)
   151  					logger.Warn(fmt.Sprintf("entcache: failed storing entry %v in cache: %v", opts.key, err))
   152  				}
   153  			},
   154  		}
   155  	default:
   156  		return d.Driver.Query(ctx, query, args, v)
   157  	}
   158  	return nil
   159  }
   160  
   161  // optionsFromContext returns the injected options from the context, or its default value.
   162  // Note that the key in the context is an entry key, and will replace by hashed query key, that will improve the cache hit rate.
   163  func (d *Driver) optionsFromContext(ctx context.Context, query string, args []any) (ctxOptions, error) {
   164  	var opts ctxOptions
   165  	if c, ok := ctx.Value(ctxOptionsKey).(*ctxOptions); ok {
   166  		opts = *c
   167  		if c.key != "" {
   168  			c.key = "" // clear it for eager loading.
   169  		}
   170  	}
   171  	key, err := d.Hash(query, args)
   172  	if err != nil {
   173  		return opts, errSkip
   174  	}
   175  	switch {
   176  	case opts.ref && opts.key != "":
   177  		if t, ok := d.ChangeSet.Load(opts.key); ok {
   178  			rt, loaded := d.ChangeSet.LoadOrStoreRef(key)
   179  			// the first query in the entity changed period, evict the cache;
   180  			// if the new entity changed happen after the previous query, evict the cache
   181  			opts.evict = !loaded || t.After(rt)
   182  		} else if _, ok := d.ChangeSet.LoadRef(key); ok {
   183  			opts.evict = true
   184  			d.ChangeSet.DeleteRef(key)
   185  		}
   186  		if opts.ttl == 0 {
   187  			opts.ttl = d.KeyQueryTTL
   188  		}
   189  	case opts.key == "":
   190  		if opts.ttl == 0 {
   191  			opts.ttl = d.HashQueryTTL
   192  		}
   193  	case opts.key != "":
   194  		if _, ok := d.ChangeSet.Load(opts.key); ok {
   195  			opts.evict = true
   196  			d.ChangeSet.Delete(opts.key)
   197  		}
   198  		if opts.ttl == 0 {
   199  			opts.ttl = d.KeyQueryTTL
   200  		}
   201  	}
   202  	// use hashed key as the cache key
   203  	opts.key = key
   204  	if d.CachePrefix != "" {
   205  		opts.key = Key(d.CachePrefix) + opts.key
   206  	}
   207  	if opts.skipMode == cache.SkipCache {
   208  		return opts, errSkip
   209  	}
   210  	return opts, nil
   211  }
   212  
   213  // rawCopy copies the driver values by implementing
   214  // the sql.Scanner interface.
   215  type rawCopy struct {
   216  	values []driver.Value
   217  }
   218  
   219  func (c *rawCopy) Scan(src interface{}) error {
   220  	if b, ok := src.([]byte); ok {
   221  		b1 := make([]byte, len(b))
   222  		copy(b1, b)
   223  		src = b1
   224  	}
   225  	c.values[0] = src
   226  	c.values = c.values[1:]
   227  	return nil
   228  }
   229  
   230  // recorder represents an sql.Rows recorder that implements
   231  // the entgo.io/ent/dialect/sql.ColumnScanner interface.
   232  type recorder struct {
   233  	sql.ColumnScanner
   234  	values  [][]driver.Value
   235  	columns []string
   236  	done    bool
   237  	onClose func([]string, [][]driver.Value)
   238  }
   239  
   240  // Next wraps the underlying Next method
   241  func (r *recorder) Next() bool {
   242  	hasNext := r.ColumnScanner.Next()
   243  	r.done = !hasNext
   244  	return hasNext
   245  }
   246  
   247  // Scan copies database values for future use (by the repeater)
   248  // and assign them to the given destinations using the standard
   249  // database/sql.convertAssign function.
   250  func (r *recorder) Scan(dest ...any) error {
   251  	values := make([]driver.Value, len(dest))
   252  	args := make([]any, len(dest))
   253  	c := &rawCopy{values: values}
   254  	for i := range args {
   255  		args[i] = c
   256  	}
   257  	if err := r.ColumnScanner.Scan(args...); err != nil {
   258  		return err
   259  	}
   260  	for i := range values {
   261  		if err := convertAssign(dest[i], values[i]); err != nil {
   262  			return err
   263  		}
   264  	}
   265  	r.values = append(r.values, values)
   266  	return nil
   267  }
   268  
   269  // Columns wraps the underlying Column method and stores it in the recorder state.
   270  // The repeater.Columns cannot be called if the recorder method was not called before.
   271  // That means, raw scanning should be identical for identical queries.
   272  func (r *recorder) Columns() ([]string, error) {
   273  	columns, err := r.ColumnScanner.Columns()
   274  	if err != nil {
   275  		return nil, err
   276  	}
   277  	r.columns = columns
   278  	return columns, nil
   279  }
   280  
   281  func (r *recorder) Close() error {
   282  	if err := r.ColumnScanner.Close(); err != nil {
   283  		return err
   284  	}
   285  	// If we did not encounter any error during iteration,
   286  	// and we scanned all rows, we store it on cache.
   287  	if err := r.ColumnScanner.Err(); err == nil || r.done {
   288  		r.onClose(r.columns, r.values)
   289  	}
   290  	return nil
   291  }
   292  
   293  // repeater repeats columns scanning from cache history.
   294  type repeater struct {
   295  	columns []string
   296  	values  [][]driver.Value
   297  }
   298  
   299  func (*repeater) Close() error {
   300  	return nil
   301  }
   302  func (*repeater) ColumnTypes() ([]*stdsql.ColumnType, error) {
   303  	return nil, fmt.Errorf("entcache.ColumnTypes is not supported")
   304  }
   305  func (r *repeater) Columns() ([]string, error) {
   306  	return r.columns, nil
   307  }
   308  func (*repeater) Err() error {
   309  	return nil
   310  }
   311  func (r *repeater) Next() bool {
   312  	return len(r.values) > 0
   313  }
   314  
   315  func (r *repeater) NextResultSet() bool {
   316  	return len(r.values) > 0
   317  }
   318  
   319  func (r *repeater) Scan(dest ...any) error {
   320  	if !r.Next() {
   321  		return stdsql.ErrNoRows
   322  	}
   323  	for i, src := range r.values[0] {
   324  		if err := convertAssign(dest[i], src); err != nil {
   325  			return err
   326  		}
   327  	}
   328  	r.values = r.values[1:]
   329  	return nil
   330  }