github.com/DeltaLaboratory/entcache@v0.1.1/driver.go

package entcache

import (
	"context"
	stdsql "database/sql"
	"database/sql/driver"
	"errors"
	"fmt"
	"strings"
	"sync/atomic"
	"time"
	_ "unsafe"

	"entgo.io/ent/dialect"
	"entgo.io/ent/dialect/sql"
	"github.com/mitchellh/hashstructure/v2"
)

type (
	// Options wraps the basic configuration options of the cache.
	Options struct {
		// TTL defines the period of time that an Entry
		// is valid in the cache.
		TTL time.Duration

		// Cache defines the AddGetDeleter (cache implementation)
		// for holding the cache entries. If no cache implementation
		// is provided, an LRU cache with no limit is used.
		Cache AddGetDeleter

		// Hash defines an optional Hash function for converting
		// a query and its arguments to a cache key. If no Hash
		// function is provided, DefaultHash is used.
		Hash func(query string, args []any) (Key, error)

		// Log function. If provided, the Driver will call it with
		// errors that cannot be handled.
		Log func(...any)
	}

	// Option allows configuring the cache
	// driver using functional options.
	Option func(*Options)

	// A Driver is an SQL-cached client.
	// Users should use the constructor below to create a new driver.
	Driver struct {
		dialect.Driver
		*Options
		stats Stats
	}
)

// NewDriver returns a new Driver that wraps an existing driver and applies
// the optional configuration functions.
// For example:
//
//	entcache.NewDriver(
//		drv,
//		entcache.TTL(time.Minute),
//		entcache.Levels(
//			NewLRU(256),
//			NewRedis(redis.NewClient(&redis.Options{
//				Addr: ":6379",
//			})),
//		),
//	)
func NewDriver(drv dialect.Driver, opts ...Option) *Driver {
	options := &Options{Hash: DefaultHash, Cache: NewLRU(0)}
	for _, opt := range opts {
		opt(options)
	}
	return &Driver{
		Driver:  drv,
		Options: options,
	}
}
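
// A minimal wiring sketch for plugging the cached driver into a generated ent
// client. The "ent" package name, the SQLite DSN, and the variable names are
// illustrative and not part of this package:
//
//	db, err := sql.Open(dialect.SQLite, "file:ent?mode=memory&cache=shared&_fk=1")
//	if err != nil {
//		log.Fatal(err)
//	}
//	drv := entcache.NewDriver(db, entcache.TTL(10*time.Second))
//	client := ent.NewClient(ent.Driver(drv))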

// TTL configures the period of time that an Entry
// is valid in the cache.
func TTL(ttl time.Duration) Option {
	return func(o *Options) {
		o.TTL = ttl
	}
}

// Hash configures an optional Hash function for
// converting a query and its arguments to a cache key.
func Hash(hash func(query string, args []any) (Key, error)) Option {
	return func(o *Options) {
		o.Hash = hash
	}
}
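
// A sketch of a custom key function. Since DefaultHash below returns a uint64
// as the Key, any stable, comparable value works; the fnv, io, and fmt calls
// are illustrative and would need to be imported by the caller:
//
//	entcache.Hash(func(query string, args []any) (entcache.Key, error) {
//		h := fnv.New64a()
//		_, _ = io.WriteString(h, query)
//		_, _ = fmt.Fprintf(h, "/%v", args)
//		return h.Sum64(), nil
//	})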

// Levels configures the Driver to work with the given cache levels.
// For example, an in-process LRU cache and a remote Redis cache.
func Levels(levels ...AddGetDeleter) Option {
	return func(o *Options) {
		if len(levels) == 1 {
			o.Cache = levels[0]
		} else {
			o.Cache = &multiLevel{levels: levels}
		}
	}
}
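
// A sketch of a single shared level, e.g. when several replicas should share
// one cache without a per-process layer in front of it (the redis client
// construction is illustrative):
//
//	drv := entcache.NewDriver(db, entcache.Levels(
//		entcache.NewRedis(redis.NewClient(&redis.Options{Addr: "redis:6379"})),
//	))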

// ContextLevel configures the driver to work with a context/request-level cache.
// Users of this option should wrap the *http.Request context with the cache
// value as follows:
//
//	ctx = entcache.NewContext(ctx)
//
// or, with an explicit cache implementation:
//
//	ctx = entcache.NewContext(ctx, entcache.NewLRU(128))
func ContextLevel() Option {
	return func(o *Options) {
		o.Cache = &contextLevel{}
	}
}
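
// A sketch of request-scoped caching in net/http middleware: each request
// starts with its own empty cache that is dropped when the request ends
// (the middleware shape is illustrative):
//
//	func withRequestCache(next http.Handler) http.Handler {
//		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
//			next.ServeHTTP(w, r.WithContext(entcache.NewContext(r.Context())))
//		})
//	}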

// Query implements the Querier interface for the driver. It falls back to the
// underlying wrapped driver in case of a caching error.
//
// Note that the driver does not synchronize identical queries that are executed
// concurrently. Hence, if two identical queries are executed at roughly the same
// time, and there is no cache entry for them, the driver will execute both of them
// and the last successful one will be stored in the cache.
func (d *Driver) Query(ctx context.Context, query string, args, v any) error {
	// Check if the given statement looks like a standard Ent query (e.g., SELECT).
	// Custom queries (e.g., CTE) or statements that are prefixed with comments are not supported.
	// This check is mainly necessary because PostgreSQL and SQLite
	// may execute an insert statement like "INSERT ... RETURNING" using Driver.Query.
	if !strings.HasPrefix(query, "SELECT") && !strings.HasPrefix(query, "select") {
		return d.Driver.Query(ctx, query, args, v)
	}
	vr, ok := v.(*sql.Rows)
	if !ok {
		return fmt.Errorf("entcache: invalid type %T. expect *sql.Rows", v)
	}
	argv, ok := args.([]any)
	if !ok {
		return fmt.Errorf("entcache: invalid type %T. expect []any for args", args)
	}
	opts, err := d.optionsFromContext(ctx, query, argv)
	if err != nil {
		return d.Driver.Query(ctx, query, args, v)
	}
	atomic.AddUint64(&d.stats.Gets, 1)
	switch e, err := d.Cache.Get(ctx, opts.key); {
	case err == nil:
		atomic.AddUint64(&d.stats.Hits, 1)
		vr.ColumnScanner = &repeater{columns: e.Columns, values: e.Values}
	case errors.Is(err, ErrNotFound):
		if err := d.Driver.Query(ctx, query, args, vr); err != nil {
			return err
		}
		vr.ColumnScanner = &recorder{
			ColumnScanner: vr.ColumnScanner,
			onClose: func(columns []string, values [][]driver.Value) {
				err := d.Cache.Add(ctx, opts.key, &Entry{Columns: columns, Values: values}, opts.ttl)
				if err != nil && d.Log != nil {
					atomic.AddUint64(&d.stats.Errors, 1)
					d.Log(fmt.Sprintf("entcache: failed storing entry %v in cache: %v", opts.key, err))
				}
			},
		}
	default:
		return d.Driver.Query(ctx, query, args, v)
	}
	return nil
}
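
// Per-query behavior is controlled through ctxOptions (see optionsFromContext
// below): a key or TTL override, eviction of the cached entry, or skipping the
// cache entirely. Assuming this fork keeps the context helpers of the upstream
// ariga/entcache package (Skip, Evict, WithTTL), a caller could bypass the
// cache for a single query of a generated ent client (illustrative):
//
//	users, err := client.User.Query().All(entcache.Skip(ctx))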

// Stats returns a copy of the cache statistics.
func (d *Driver) Stats() Stats {
	return Stats{
		Gets:   atomic.LoadUint64(&d.stats.Gets),
		Hits:   atomic.LoadUint64(&d.stats.Hits),
		Errors: atomic.LoadUint64(&d.stats.Errors),
	}
}
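
// A sketch of surfacing the counters, e.g. from a metrics endpoint or a
// periodic goroutine (the logging is illustrative):
//
//	s := drv.Stats()
//	if s.Gets > 0 {
//		log.Printf("entcache: hit ratio %.2f (%d/%d), errors=%d",
//			float64(s.Hits)/float64(s.Gets), s.Hits, s.Gets, s.Errors)
//	}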

// QueryContext calls QueryContext of the underlying driver, or fails if it is not supported.
// Note, this method is not part of the caching layer since Ent does not use it by default.
func (d *Driver) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
	drv, ok := d.Driver.(interface {
		QueryContext(context.Context, string, ...any) (*sql.Rows, error)
	})
	if !ok {
		return nil, fmt.Errorf("Driver.QueryContext is not supported")
	}
	return drv.QueryContext(ctx, query, args...)
}

// ExecContext calls ExecContext of the underlying driver, or fails if it is not supported.
func (d *Driver) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
	drv, ok := d.Driver.(interface {
		ExecContext(context.Context, string, ...any) (sql.Result, error)
	})
	if !ok {
		return nil, fmt.Errorf("Driver.ExecContext is not supported")
	}
	return drv.ExecContext(ctx, query, args...)
}

// errSkip tells the driver to skip the cache layer.
var errSkip = errors.New("entcache: skip cache")

// optionsFromContext returns the options injected into the context, or their default values.
func (d *Driver) optionsFromContext(ctx context.Context, query string, args []any) (ctxOptions, error) {
	var opts ctxOptions
	if c, ok := ctx.Value(ctxOptionsKey).(*ctxOptions); ok {
		opts = *c
	}
	if opts.key == nil {
		key, err := d.Hash(query, args)
		if err != nil {
			return opts, errSkip
		}
		opts.key = key
	}
	if opts.ttl == 0 {
		opts.ttl = d.TTL
	}
	if opts.evict {
		if err := d.Cache.Del(ctx, opts.key); err != nil {
			return opts, err
		}
	}
	if opts.skip {
		return opts, errSkip
	}
	return opts, nil
}

// DefaultHash provides the default implementation for converting
// a query and its arguments to a cache key.
func DefaultHash(query string, args []any) (Key, error) {
	key, err := hashstructure.Hash(struct {
		Q string
		A []any
	}{
		Q: query,
		A: args,
	}, hashstructure.FormatV2, nil)
	if err != nil {
		return nil, err
	}
	return key, nil
}
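
// Because Driver embeds *Options, the Hash and Cache used by Query are also
// reachable from callers, so a specific entry can be evicted by recomputing
// its key. The query text and arguments below are illustrative:
//
//	key, err := drv.Hash("SELECT `users`.`id` FROM `users` WHERE `id` = ?", []any{1})
//	if err == nil {
//		_ = drv.Cache.Del(ctx, key)
//	}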

// Stats represents the cache statistics of the driver.
type Stats struct {
	Gets   uint64
	Hits   uint64
	Errors uint64
}

// rawCopy copies the driver values by implementing
// the sql.Scanner interface.
type rawCopy struct {
	values []driver.Value
}

func (c *rawCopy) Scan(src interface{}) error {
	if b, ok := src.([]byte); ok {
		b1 := make([]byte, len(b))
		copy(b1, b)
		src = b1
	}
	c.values[0] = src
	c.values = c.values[1:]
	return nil
}

// recorder represents a sql.Rows recorder that implements
// the entgo.io/ent/dialect/sql.ColumnScanner interface.
type recorder struct {
	sql.ColumnScanner
	values  [][]driver.Value
	columns []string
	done    bool
	onClose func([]string, [][]driver.Value)
}

// Next wraps the underlying Next method and records whether the
// rows were iterated to completion.
func (r *recorder) Next() bool {
	hasNext := r.ColumnScanner.Next()
	r.done = !hasNext
	return hasNext
}

// Scan copies the database values for future use (by the repeater)
// and assigns them to the given destinations using the standard
// database/sql.convertAssign function.
func (r *recorder) Scan(dest ...any) error {
	values := make([]driver.Value, len(dest))
	args := make([]any, len(dest))
	c := &rawCopy{values: values}
	for i := range args {
		args[i] = c
	}
	if err := r.ColumnScanner.Scan(args...); err != nil {
		return err
	}
	for i := range values {
		if err := convertAssign(dest[i], values[i]); err != nil {
			return err
		}
	}
	r.values = append(r.values, values)
	return nil
}

// Columns wraps the underlying Columns method and stores the result in the
// recorder state. repeater.Columns cannot be called unless this method was
// called while recording, which means raw scanning should be identical for
// identical queries.
func (r *recorder) Columns() ([]string, error) {
	columns, err := r.ColumnScanner.Columns()
	if err != nil {
		return nil, err
	}
	r.columns = columns
	return columns, nil
}

func (r *recorder) Close() error {
	if err := r.ColumnScanner.Close(); err != nil {
		return err
	}
	// If we did not encounter any error during iteration,
	// or we scanned all rows, store the result in the cache.
	if err := r.ColumnScanner.Err(); err == nil || r.done {
		r.onClose(r.columns, r.values)
	}
	return nil
}

// repeater repeats column scanning from the cached history.
type repeater struct {
	columns []string
	values  [][]driver.Value
}

func (*repeater) Close() error {
	return nil
}
func (*repeater) ColumnTypes() ([]*stdsql.ColumnType, error) {
	return nil, fmt.Errorf("entcache.ColumnTypes is not supported")
}
func (r *repeater) Columns() ([]string, error) {
	return r.columns, nil
}
func (*repeater) Err() error {
	return nil
}
func (r *repeater) Next() bool {
	return len(r.values) > 0
}
func (r *repeater) NextResultSet() bool {
	return len(r.values) > 0
}

func (r *repeater) Scan(dest ...any) error {
	if !r.Next() {
		return stdsql.ErrNoRows
	}
	for i, src := range r.values[0] {
		if err := convertAssign(dest[i], src); err != nil {
			return err
		}
	}
	r.values = r.values[1:]
	return nil
}

//go:linkname convertAssign database/sql.convertAssign
func convertAssign(dest, src any) error