github.com/argoproj/argo-cd/v3@v3.2.1/util/cache/redis.go (about)

     1  package cache
     2  
     3  import (
     4  	"bytes"
     5  	"compress/gzip"
     6  	"context"
     7  	"encoding/json"
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"net"
    12  	"sync"
    13  	"time"
    14  
    15  	utilio "github.com/argoproj/argo-cd/v3/util/io"
    16  
    17  	rediscache "github.com/go-redis/cache/v9"
    18  	"github.com/redis/go-redis/v9"
    19  )
    20  
    21  type RedisCompressionType string
    22  
    23  var (
    24  	RedisCompressionNone RedisCompressionType = "none"
    25  	RedisCompressionGZip RedisCompressionType = "gzip"
    26  )
    27  
    28  func CompressionTypeFromString(s string) (RedisCompressionType, error) {
    29  	switch s {
    30  	case string(RedisCompressionNone):
    31  		return RedisCompressionNone, nil
    32  	case string(RedisCompressionGZip):
    33  		return RedisCompressionGZip, nil
    34  	}
    35  	return "", fmt.Errorf("unknown compression type: %s", s)
    36  }
    37  
    38  func NewRedisCache(client *redis.Client, expiration time.Duration, compressionType RedisCompressionType) CacheClient {
    39  	return &redisCache{
    40  		client:               client,
    41  		expiration:           expiration,
    42  		cache:                rediscache.New(&rediscache.Options{Redis: client}),
    43  		redisCompressionType: compressionType,
    44  	}
    45  }
    46  
    47  // compile-time validation of adherence of the CacheClient contract
    48  var _ CacheClient = &redisCache{}
    49  
    50  type redisCache struct {
    51  	expiration           time.Duration
    52  	client               *redis.Client
    53  	cache                *rediscache.Cache
    54  	redisCompressionType RedisCompressionType
    55  }
    56  
    57  func (r *redisCache) getKey(key string) string {
    58  	switch r.redisCompressionType {
    59  	case RedisCompressionGZip:
    60  		return key + ".gz"
    61  	default:
    62  		return key
    63  	}
    64  }
    65  
    66  func (r *redisCache) marshal(obj any) ([]byte, error) {
    67  	buf := bytes.NewBuffer([]byte{})
    68  	var w io.Writer = buf
    69  	if r.redisCompressionType == RedisCompressionGZip {
    70  		w = gzip.NewWriter(buf)
    71  	}
    72  	encoder := json.NewEncoder(w)
    73  
    74  	if err := encoder.Encode(obj); err != nil {
    75  		return nil, err
    76  	}
    77  	if flusher, ok := w.(interface{ Flush() error }); ok {
    78  		if err := flusher.Flush(); err != nil {
    79  			return nil, err
    80  		}
    81  	}
    82  	if closer, ok := w.(interface{ Close() error }); ok {
    83  		if err := closer.Close(); err != nil {
    84  			return nil, err
    85  		}
    86  	}
    87  	return buf.Bytes(), nil
    88  }
    89  
    90  func (r *redisCache) unmarshal(data []byte, obj any) error {
    91  	buf := bytes.NewReader(data)
    92  	var reader io.Reader = buf
    93  	if r.redisCompressionType == RedisCompressionGZip {
    94  		gzipReader, err := gzip.NewReader(buf)
    95  		if err != nil {
    96  			return err
    97  		}
    98  		reader = gzipReader
    99  	}
   100  	if err := json.NewDecoder(reader).Decode(obj); err != nil {
   101  		return fmt.Errorf("failed to decode cached data: %w", err)
   102  	}
   103  	return nil
   104  }
   105  
   106  func (r *redisCache) Rename(oldKey string, newKey string, _ time.Duration) error {
   107  	err := r.client.Rename(context.TODO(), r.getKey(oldKey), r.getKey(newKey)).Err()
   108  	if err != nil && err.Error() == "ERR no such key" {
   109  		err = ErrCacheMiss
   110  	}
   111  
   112  	return err
   113  }
   114  
   115  func (r *redisCache) Set(item *Item) error {
   116  	expiration := item.CacheActionOpts.Expiration
   117  	if expiration == 0 {
   118  		expiration = r.expiration
   119  	}
   120  
   121  	val, err := r.marshal(item.Object)
   122  	if err != nil {
   123  		return err
   124  	}
   125  
   126  	return r.cache.Set(&rediscache.Item{
   127  		Key:   r.getKey(item.Key),
   128  		Value: val,
   129  		TTL:   expiration,
   130  		SetNX: item.CacheActionOpts.DisableOverwrite,
   131  	})
   132  }
   133  
   134  func (r *redisCache) Get(key string, obj any) error {
   135  	var data []byte
   136  	err := r.cache.Get(context.TODO(), r.getKey(key), &data)
   137  	if errors.Is(err, rediscache.ErrCacheMiss) {
   138  		err = ErrCacheMiss
   139  	}
   140  	if err != nil {
   141  		return err
   142  	}
   143  	return r.unmarshal(data, obj)
   144  }
   145  
   146  func (r *redisCache) Delete(key string) error {
   147  	return r.cache.Delete(context.TODO(), r.getKey(key))
   148  }
   149  
   150  func (r *redisCache) OnUpdated(ctx context.Context, key string, callback func() error) error {
   151  	pubsub := r.client.Subscribe(ctx, key)
   152  	defer utilio.Close(pubsub)
   153  
   154  	ch := pubsub.Channel()
   155  	for {
   156  		select {
   157  		case <-ctx.Done():
   158  			return nil
   159  		case <-ch:
   160  			if err := callback(); err != nil {
   161  				return err
   162  			}
   163  		}
   164  	}
   165  }
   166  
   167  func (r *redisCache) NotifyUpdated(key string) error {
   168  	return r.client.Publish(context.TODO(), key, "").Err()
   169  }
   170  
   171  type MetricsRegistry interface {
   172  	IncRedisRequest(failed bool)
   173  	ObserveRedisRequestDuration(duration time.Duration)
   174  }
   175  
   176  type redisHook struct {
   177  	registry MetricsRegistry
   178  }
   179  
   180  func (rh *redisHook) DialHook(next redis.DialHook) redis.DialHook {
   181  	return func(ctx context.Context, network, addr string) (net.Conn, error) {
   182  		conn, err := next(ctx, network, addr)
   183  		return conn, err
   184  	}
   185  }
   186  
   187  func (rh *redisHook) ProcessHook(next redis.ProcessHook) redis.ProcessHook {
   188  	return func(ctx context.Context, cmd redis.Cmder) error {
   189  		startTime := time.Now()
   190  
   191  		err := next(ctx, cmd)
   192  		rh.registry.IncRedisRequest(err != nil && !errors.Is(err, redis.Nil))
   193  		rh.registry.ObserveRedisRequestDuration(time.Since(startTime))
   194  
   195  		return err
   196  	}
   197  }
   198  
   199  func (redisHook) ProcessPipelineHook(_ redis.ProcessPipelineHook) redis.ProcessPipelineHook {
   200  	return nil
   201  }
   202  
   203  // CollectMetrics add transport wrapper that pushes metrics into the specified metrics registry
   204  // Lock should be shared between functions that can add/process a Redis hook.
   205  func CollectMetrics(client *redis.Client, registry MetricsRegistry, lock *sync.RWMutex) {
   206  	if lock != nil {
   207  		lock.Lock()
   208  		defer lock.Unlock()
   209  	}
   210  	client.AddHook(&redisHook{registry: registry})
   211  }