ariga.io/entcache@v0.1.1-0.20230620164151-0eb723a11c40/driver.go (about)

     1  package entcache
     2  
     3  import (
     4  	"context"
     5  	stdsql "database/sql"
     6  	"database/sql/driver"
     7  	"errors"
     8  	"fmt"
     9  	"strings"
    10  	"sync/atomic"
    11  	"time"
    12  	_ "unsafe"
    13  
    14  	"entgo.io/ent/dialect"
    15  	"entgo.io/ent/dialect/sql"
    16  	"github.com/mitchellh/hashstructure/v2"
    17  )
    18  
type (
	// Options wraps the basic configuration cache options.
	Options struct {
		// TTL defines the period of time that an Entry
		// is valid in the cache.
		TTL time.Duration

		// Cache defines the GetAddDeleter (cache implementation)
		// for holding the cache entries. If no cache implementation
		// was provided, an LRU cache with no limit is used.
		Cache AddGetDeleter

		// Hash defines an optional Hash function for converting
		// a query and its arguments to a cache key. If no Hash
		// function was provided, the DefaultHash is used.
		Hash func(query string, args []any) (Key, error)

		// Log is an optional logging function. If provided, the Driver
		// will call it with errors that can not be handled (e.g. a
		// failure to store an entry in the cache).
		Log func(...any)
	}

	// Option allows configuring the cache
	// driver using functional options.
	Option func(*Options)

	// A Driver is an SQL cached client. Users should use the
	// constructor below for creating new driver.
	Driver struct {
		dialect.Driver // underlying (wrapped) dialect driver
		*Options
		stats Stats // hit/miss/error counters, mutated atomically
	}
)
    53  
    54  // NewDriver returns a new Driver an existing driver and optional
    55  // configuration functions. For example:
    56  //
    57  //	entcache.NewDriver(
    58  //		drv,
    59  //		entcache.TTL(time.Minute),
    60  //		entcache.Levels(
    61  //			NewLRU(256),
    62  //			NewRedis(redis.NewClient(&redis.Options{
    63  //				Addr: ":6379",
    64  //			})),
    65  //		)
    66  //	)
    67  func NewDriver(drv dialect.Driver, opts ...Option) *Driver {
    68  	options := &Options{Hash: DefaultHash, Cache: NewLRU(0)}
    69  	for _, opt := range opts {
    70  		opt(options)
    71  	}
    72  	return &Driver{
    73  		Driver:  drv,
    74  		Options: options,
    75  	}
    76  }
    77  
    78  // TTL configures the period of time that an Entry
    79  // is valid in the cache.
    80  func TTL(ttl time.Duration) Option {
    81  	return func(o *Options) {
    82  		o.TTL = ttl
    83  	}
    84  }
    85  
    86  // Hash configures an optional Hash function for
    87  // converting a query and its arguments to a cache key.
    88  func Hash(hash func(query string, args []any) (Key, error)) Option {
    89  	return func(o *Options) {
    90  		o.Hash = hash
    91  	}
    92  }
    93  
    94  // Levels configures the Driver to work with the given cache levels.
    95  // For example, in process LRU cache and a remote Redis cache.
    96  func Levels(levels ...AddGetDeleter) Option {
    97  	return func(o *Options) {
    98  		if len(levels) == 1 {
    99  			o.Cache = levels[0]
   100  		} else {
   101  			o.Cache = &multiLevel{levels: levels}
   102  		}
   103  	}
   104  }
   105  
   106  // ContextLevel configures the driver to work with context/request level cache.
   107  // Users that use this option, should wraps the *http.Request context with the
   108  // cache value as follows:
   109  //
   110  //	ctx = entcache.NewContext(ctx)
   111  //
   112  //	ctx = entcache.NewContext(ctx, entcache.NewLRU(128))
   113  func ContextLevel() Option {
   114  	return func(o *Options) {
   115  		o.Cache = &contextLevel{}
   116  	}
   117  }
   118  
   119  // Query implements the Querier interface for the driver. It falls back to the
   120  // underlying wrapped driver in case of caching error.
   121  //
   122  // Note that, the driver does not synchronize identical queries that are executed
   123  // concurrently. Hence, if 2 identical queries are executed at the ~same time, and
   124  // there is no cache entry for them, the driver will execute both of them and the
   125  // last successful one will be stored in the cache.
   126  func (d *Driver) Query(ctx context.Context, query string, args, v any) error {
   127  	// Check if the given statement looks like a standard Ent query (e.g. SELECT).
   128  	// Custom queries (e.g. CTE) or statements that are prefixed with comments are
   129  	// not supported. This check is mainly necessary, because PostgreSQL and SQLite
   130  	// may execute insert statement like "INSERT ... RETURNING" using Driver.Query.
   131  	if !strings.HasPrefix(query, "SELECT") && !strings.HasPrefix(query, "select") {
   132  		return d.Driver.Query(ctx, query, args, v)
   133  	}
   134  	vr, ok := v.(*sql.Rows)
   135  	if !ok {
   136  		return fmt.Errorf("entcache: invalid type %T. expect *sql.Rows", v)
   137  	}
   138  	argv, ok := args.([]any)
   139  	if !ok {
   140  		return fmt.Errorf("entcache: invalid type %T. expect []interface{} for args", args)
   141  	}
   142  	opts, err := d.optionsFromContext(ctx, query, argv)
   143  	if err != nil {
   144  		return d.Driver.Query(ctx, query, args, v)
   145  	}
   146  	atomic.AddUint64(&d.stats.Gets, 1)
   147  	switch e, err := d.Cache.Get(ctx, opts.key); {
   148  	case err == nil:
   149  		atomic.AddUint64(&d.stats.Hits, 1)
   150  		vr.ColumnScanner = &repeater{columns: e.Columns, values: e.Values}
   151  	case err == ErrNotFound:
   152  		if err := d.Driver.Query(ctx, query, args, vr); err != nil {
   153  			return err
   154  		}
   155  		vr.ColumnScanner = &recorder{
   156  			ColumnScanner: vr.ColumnScanner,
   157  			onClose: func(columns []string, values [][]driver.Value) {
   158  				err := d.Cache.Add(ctx, opts.key, &Entry{Columns: columns, Values: values}, opts.ttl)
   159  				if err != nil && d.Log != nil {
   160  					atomic.AddUint64(&d.stats.Errors, 1)
   161  					d.Log(fmt.Sprintf("entcache: failed storing entry %v in cache: %v", opts.key, err))
   162  				}
   163  			},
   164  		}
   165  	default:
   166  		return d.Driver.Query(ctx, query, args, v)
   167  	}
   168  	return nil
   169  }
   170  
   171  // Stats returns a copy of the cache statistics.
   172  func (d *Driver) Stats() Stats {
   173  	return Stats{
   174  		Gets:   atomic.LoadUint64(&d.stats.Gets),
   175  		Hits:   atomic.LoadUint64(&d.stats.Hits),
   176  		Errors: atomic.LoadUint64(&d.stats.Errors),
   177  	}
   178  }
   179  
   180  // QueryContext calls QueryContext of the underlying driver, or fails if it is not supported.
   181  // Note, this method is not part of the caching layer since Ent does not use it by default.
   182  func (d *Driver) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
   183  	drv, ok := d.Driver.(interface {
   184  		QueryContext(context.Context, string, ...any) (*sql.Rows, error)
   185  	})
   186  	if !ok {
   187  		return nil, fmt.Errorf("Driver.QueryContext is not supported")
   188  	}
   189  	return drv.QueryContext(ctx, query, args...)
   190  }
   191  
   192  // ExecContext calls ExecContext of the underlying driver, or fails if it is not supported.
   193  func (d *Driver) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
   194  	drv, ok := d.Driver.(interface {
   195  		ExecContext(context.Context, string, ...any) (sql.Result, error)
   196  	})
   197  	if !ok {
   198  		return nil, fmt.Errorf("Driver.ExecContext is not supported")
   199  	}
   200  	return drv.ExecContext(ctx, query, args...)
   201  }
   202  
// errSkip is a sentinel error returned by optionsFromContext to tell the
// driver to bypass the cache layer and query the underlying driver directly.
var errSkip = errors.New("entcache: skip cache")
   205  
   206  // optionsFromContext returns the injected options from the context, or its default value.
   207  func (d *Driver) optionsFromContext(ctx context.Context, query string, args []any) (ctxOptions, error) {
   208  	var opts ctxOptions
   209  	if c, ok := ctx.Value(ctxOptionsKey).(*ctxOptions); ok {
   210  		opts = *c
   211  	}
   212  	if opts.key == nil {
   213  		key, err := d.Hash(query, args)
   214  		if err != nil {
   215  			return opts, errSkip
   216  		}
   217  		opts.key = key
   218  	}
   219  	if opts.ttl == 0 {
   220  		opts.ttl = d.TTL
   221  	}
   222  	if opts.evict {
   223  		if err := d.Cache.Del(ctx, opts.key); err != nil {
   224  			return opts, err
   225  		}
   226  	}
   227  	if opts.skip {
   228  		return opts, errSkip
   229  	}
   230  	return opts, nil
   231  }
   232  
   233  // DefaultHash provides the default implementation for converting
   234  // a query and its argument to a cache key.
   235  func DefaultHash(query string, args []any) (Key, error) {
   236  	key, err := hashstructure.Hash(struct {
   237  		Q string
   238  		A []any
   239  	}{
   240  		Q: query,
   241  		A: args,
   242  	}, hashstructure.FormatV2, nil)
   243  	if err != nil {
   244  		return nil, err
   245  	}
   246  	return key, nil
   247  }
   248  
// Stats represents the cache statistics of the driver.
type Stats struct {
	Gets   uint64 // total cache lookups performed by Query
	Hits   uint64 // lookups served from the cache
	Errors uint64 // failures to store an entry in the cache
}
   255  
   256  // rawCopy copies the driver values by implementing
   257  // the sql.Scanner interface.
   258  type rawCopy struct {
   259  	values []driver.Value
   260  }
   261  
   262  func (c *rawCopy) Scan(src interface{}) error {
   263  	if b, ok := src.([]byte); ok {
   264  		b1 := make([]byte, len(b))
   265  		copy(b1, b)
   266  		src = b1
   267  	}
   268  	c.values[0] = src
   269  	c.values = c.values[1:]
   270  	return nil
   271  }
   272  
// recorder represents an sql.Rows recorder that implements
// the entgo.io/ent/dialect/sql.ColumnScanner interface.
type recorder struct {
	sql.ColumnScanner
	values  [][]driver.Value // raw rows captured by Scan
	columns []string         // column names captured by Columns
	done    bool             // set when the underlying Next reports no more rows
	onClose func([]string, [][]driver.Value) // called on Close with the recorded result set
}
   282  
   283  // Next wraps the underlying Next method
   284  func (r *recorder) Next() bool {
   285  	hasNext := r.ColumnScanner.Next()
   286  	r.done = !hasNext
   287  	return hasNext
   288  }
   289  
   290  // Scan copies database values for future use (by the repeater)
   291  // and assign them to the given destinations using the standard
   292  // database/sql.convertAssign function.
   293  func (r *recorder) Scan(dest ...any) error {
   294  	values := make([]driver.Value, len(dest))
   295  	args := make([]any, len(dest))
   296  	c := &rawCopy{values: values}
   297  	for i := range args {
   298  		args[i] = c
   299  	}
   300  	if err := r.ColumnScanner.Scan(args...); err != nil {
   301  		return err
   302  	}
   303  	for i := range values {
   304  		if err := convertAssign(dest[i], values[i]); err != nil {
   305  			return err
   306  		}
   307  	}
   308  	r.values = append(r.values, values)
   309  	return nil
   310  }
   311  
   312  // Columns wraps the underlying Column method and stores it in the recorder state.
   313  // The repeater.Columns cannot be called if the recorder method was not called before.
   314  // That means, raw scanning should be identical for identical queries.
   315  func (r *recorder) Columns() ([]string, error) {
   316  	columns, err := r.ColumnScanner.Columns()
   317  	if err != nil {
   318  		return nil, err
   319  	}
   320  	r.columns = columns
   321  	return columns, nil
   322  }
   323  
   324  func (r *recorder) Close() error {
   325  	if err := r.ColumnScanner.Close(); err != nil {
   326  		return err
   327  	}
   328  	// If we did not encounter any error during iteration,
   329  	// and we scanned all rows, we store it on cache.
   330  	if err := r.ColumnScanner.Err(); err == nil || r.done {
   331  		r.onClose(r.columns, r.values)
   332  	}
   333  	return nil
   334  }
   335  
   336  // repeater repeats columns scanning from cache history.
   337  type repeater struct {
   338  	columns []string
   339  	values  [][]driver.Value
   340  }
   341  
   342  func (*repeater) Close() error {
   343  	return nil
   344  }
   345  func (*repeater) ColumnTypes() ([]*stdsql.ColumnType, error) {
   346  	return nil, fmt.Errorf("entcache.ColumnTypes is not supported")
   347  }
   348  func (r *repeater) Columns() ([]string, error) {
   349  	return r.columns, nil
   350  }
   351  func (*repeater) Err() error {
   352  	return nil
   353  }
   354  func (r *repeater) Next() bool {
   355  	return len(r.values) > 0
   356  }
   357  func (r *repeater) NextResultSet() bool {
   358  	return len(r.values) > 0
   359  }
   360  
   361  func (r *repeater) Scan(dest ...any) error {
   362  	if !r.Next() {
   363  		return stdsql.ErrNoRows
   364  	}
   365  	for i, src := range r.values[0] {
   366  		if err := convertAssign(dest[i], src); err != nil {
   367  			return err
   368  		}
   369  	}
   370  	r.values = r.values[1:]
   371  	return nil
   372  }
   373  
// convertAssign is database/sql's unexported conversion helper, linked in
// so that cached driver.Values are assigned to scan destinations exactly
// the way the standard library does. Requires the blank "unsafe" import.
//
//go:linkname convertAssign database/sql.convertAssign
func convertAssign(dest, src any) error