github.com/argoproj/argo-cd/v2@v2.10.5/util/cache/redis.go (about)

     1  package cache
     2  
     3  import (
     4  	"bytes"
     5  	"compress/gzip"
     6  	"context"
     7  	"encoding/json"
     8  	"fmt"
     9  	"io"
    10  	"net"
    11  	"time"
    12  
    13  	ioutil "github.com/argoproj/argo-cd/v2/util/io"
    14  
    15  	rediscache "github.com/go-redis/cache/v9"
    16  	"github.com/redis/go-redis/v9"
    17  )
    18  
    19  type RedisCompressionType string
    20  
    21  var (
    22  	RedisCompressionNone RedisCompressionType = "none"
    23  	RedisCompressionGZip RedisCompressionType = "gzip"
    24  )
    25  
    26  func CompressionTypeFromString(s string) (RedisCompressionType, error) {
    27  	switch s {
    28  	case string(RedisCompressionNone):
    29  		return RedisCompressionNone, nil
    30  	case string(RedisCompressionGZip):
    31  		return RedisCompressionGZip, nil
    32  	}
    33  	return "", fmt.Errorf("unknown compression type: %s", s)
    34  }
    35  
    36  func NewRedisCache(client *redis.Client, expiration time.Duration, compressionType RedisCompressionType) CacheClient {
    37  	return &redisCache{
    38  		client:               client,
    39  		expiration:           expiration,
    40  		cache:                rediscache.New(&rediscache.Options{Redis: client}),
    41  		redisCompressionType: compressionType,
    42  	}
    43  }
    44  
    45  // compile-time validation of adherance of the CacheClient contract
    46  var _ CacheClient = &redisCache{}
    47  
    48  type redisCache struct {
    49  	expiration           time.Duration
    50  	client               *redis.Client
    51  	cache                *rediscache.Cache
    52  	redisCompressionType RedisCompressionType
    53  }
    54  
    55  func (r *redisCache) getKey(key string) string {
    56  	switch r.redisCompressionType {
    57  	case RedisCompressionGZip:
    58  		return key + ".gz"
    59  	default:
    60  		return key
    61  	}
    62  }
    63  
    64  func (r *redisCache) marshal(obj interface{}) ([]byte, error) {
    65  	buf := bytes.NewBuffer([]byte{})
    66  	var w io.Writer = buf
    67  	if r.redisCompressionType == RedisCompressionGZip {
    68  		w = gzip.NewWriter(buf)
    69  	}
    70  	encoder := json.NewEncoder(w)
    71  
    72  	if err := encoder.Encode(obj); err != nil {
    73  		return nil, err
    74  	}
    75  	if flusher, ok := w.(interface{ Flush() error }); ok {
    76  		if err := flusher.Flush(); err != nil {
    77  			return nil, err
    78  		}
    79  	}
    80  	return buf.Bytes(), nil
    81  }
    82  
    83  func (r *redisCache) unmarshal(data []byte, obj interface{}) error {
    84  	buf := bytes.NewReader(data)
    85  	var reader io.Reader = buf
    86  	if r.redisCompressionType == RedisCompressionGZip {
    87  		if gzipReader, err := gzip.NewReader(buf); err != nil {
    88  			return err
    89  		} else {
    90  			reader = gzipReader
    91  		}
    92  	}
    93  	if err := json.NewDecoder(reader).Decode(obj); err != nil {
    94  		return fmt.Errorf("failed to decode cached data: %w", err)
    95  	}
    96  	return nil
    97  }
    98  
    99  func (r *redisCache) Set(item *Item) error {
   100  	expiration := item.Expiration
   101  	if expiration == 0 {
   102  		expiration = r.expiration
   103  	}
   104  
   105  	val, err := r.marshal(item.Object)
   106  	if err != nil {
   107  		return err
   108  	}
   109  
   110  	return r.cache.Set(&rediscache.Item{
   111  		Key:   r.getKey(item.Key),
   112  		Value: val,
   113  		TTL:   expiration,
   114  	})
   115  }
   116  
   117  func (r *redisCache) Get(key string, obj interface{}) error {
   118  	var data []byte
   119  	err := r.cache.Get(context.TODO(), r.getKey(key), &data)
   120  	if err == rediscache.ErrCacheMiss {
   121  		err = ErrCacheMiss
   122  	}
   123  	if err != nil {
   124  		return err
   125  	}
   126  	return r.unmarshal(data, obj)
   127  }
   128  
   129  func (r *redisCache) Delete(key string) error {
   130  	return r.cache.Delete(context.TODO(), r.getKey(key))
   131  }
   132  
   133  func (r *redisCache) OnUpdated(ctx context.Context, key string, callback func() error) error {
   134  	pubsub := r.client.Subscribe(ctx, key)
   135  	defer ioutil.Close(pubsub)
   136  
   137  	ch := pubsub.Channel()
   138  	for {
   139  		select {
   140  		case <-ctx.Done():
   141  			return nil
   142  		case <-ch:
   143  			if err := callback(); err != nil {
   144  				return err
   145  			}
   146  		}
   147  	}
   148  }
   149  
   150  func (r *redisCache) NotifyUpdated(key string) error {
   151  	return r.client.Publish(context.TODO(), key, "").Err()
   152  }
   153  
   154  type MetricsRegistry interface {
   155  	IncRedisRequest(failed bool)
   156  	ObserveRedisRequestDuration(duration time.Duration)
   157  }
   158  
   159  type redisHook struct {
   160  	registry MetricsRegistry
   161  }
   162  
   163  func (rh *redisHook) DialHook(next redis.DialHook) redis.DialHook {
   164  	return func(ctx context.Context, network, addr string) (net.Conn, error) {
   165  		conn, err := next(ctx, network, addr)
   166  		return conn, err
   167  	}
   168  }
   169  
   170  func (rh *redisHook) ProcessHook(next redis.ProcessHook) redis.ProcessHook {
   171  	return func(ctx context.Context, cmd redis.Cmder) error {
   172  		startTime := time.Now()
   173  
   174  		err := next(ctx, cmd)
   175  		rh.registry.IncRedisRequest(err != nil && err != redis.Nil)
   176  		rh.registry.ObserveRedisRequestDuration(time.Since(startTime))
   177  
   178  		return err
   179  	}
   180  }
   181  
   182  func (redisHook) ProcessPipelineHook(next redis.ProcessPipelineHook) redis.ProcessPipelineHook {
   183  	return nil
   184  }
   185  
   186  // CollectMetrics add transport wrapper that pushes metrics into the specified metrics registry
   187  func CollectMetrics(client *redis.Client, registry MetricsRegistry) {
   188  	client.AddHook(&redisHook{registry: registry})
   189  }