github.com/argoproj/argo-cd/v2@v2.10.5/util/cache/cache.go (about)

     1  package cache
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"math"
     7  	"os"
     8  	"time"
     9  
    10  	"crypto/tls"
    11  	"crypto/x509"
    12  
    13  	"github.com/redis/go-redis/v9"
    14  	"github.com/spf13/cobra"
    15  
    16  	"github.com/argoproj/argo-cd/v2/common"
    17  	certutil "github.com/argoproj/argo-cd/v2/util/cert"
    18  	"github.com/argoproj/argo-cd/v2/util/env"
    19  )
    20  
    21  const (
    22  	// envRedisPassword is an env variable name which stores redis password
    23  	envRedisPassword = "REDIS_PASSWORD"
    24  	// envRedisUsername is an env variable name which stores redis username (for acl setup)
    25  	envRedisUsername = "REDIS_USERNAME"
    26  	// envRedisRetryCount is an env variable name which stores redis retry count
    27  	envRedisRetryCount = "REDIS_RETRY_COUNT"
    28  	// defaultRedisRetryCount holds default number of retries
    29  	defaultRedisRetryCount = 3
    30  )
    31  
    32  const (
    33  	// CLIFlagRedisCompress is a cli flag name to define the redis compression setting for data sent to redis
    34  	CLIFlagRedisCompress = "redis-compress"
    35  )
    36  
    37  func NewCache(client CacheClient) *Cache {
    38  	return &Cache{client}
    39  }
    40  
    41  func buildRedisClient(redisAddress, password, username string, redisDB, maxRetries int, tlsConfig *tls.Config) *redis.Client {
    42  	opts := &redis.Options{
    43  		Addr:       redisAddress,
    44  		Password:   password,
    45  		DB:         redisDB,
    46  		MaxRetries: maxRetries,
    47  		TLSConfig:  tlsConfig,
    48  		Username:   username,
    49  	}
    50  
    51  	client := redis.NewClient(opts)
    52  
    53  	client.AddHook(redis.Hook(NewArgoRedisHook(func() {
    54  		*client = *buildRedisClient(redisAddress, password, username, redisDB, maxRetries, tlsConfig)
    55  	})))
    56  
    57  	return client
    58  }
    59  
    60  func buildFailoverRedisClient(sentinelMaster, password, username string, redisDB, maxRetries int, tlsConfig *tls.Config, sentinelAddresses []string) *redis.Client {
    61  	opts := &redis.FailoverOptions{
    62  		MasterName:    sentinelMaster,
    63  		SentinelAddrs: sentinelAddresses,
    64  		DB:            redisDB,
    65  		Password:      password,
    66  		MaxRetries:    maxRetries,
    67  		TLSConfig:     tlsConfig,
    68  		Username:      username,
    69  	}
    70  
    71  	client := redis.NewFailoverClient(opts)
    72  
    73  	client.AddHook(redis.Hook(NewArgoRedisHook(func() {
    74  		*client = *buildFailoverRedisClient(sentinelMaster, password, username, redisDB, maxRetries, tlsConfig, sentinelAddresses)
    75  	})))
    76  
    77  	return client
    78  }
    79  
    80  // AddCacheFlagsToCmd adds flags which control caching to the specified command
    81  func AddCacheFlagsToCmd(cmd *cobra.Command, opts ...func(client *redis.Client)) func() (*Cache, error) {
    82  	redisAddress := ""
    83  	sentinelAddresses := make([]string, 0)
    84  	sentinelMaster := ""
    85  	redisDB := 0
    86  	redisCACertificate := ""
    87  	redisClientCertificate := ""
    88  	redisClientKey := ""
    89  	redisUseTLS := false
    90  	insecureRedis := false
    91  	compressionStr := ""
    92  	var defaultCacheExpiration time.Duration
    93  
    94  	cmd.Flags().StringVar(&redisAddress, "redis", env.StringFromEnv("REDIS_SERVER", ""), "Redis server hostname and port (e.g. argocd-redis:6379). ")
    95  	cmd.Flags().IntVar(&redisDB, "redisdb", env.ParseNumFromEnv("REDISDB", 0, 0, math.MaxInt32), "Redis database.")
    96  	cmd.Flags().StringArrayVar(&sentinelAddresses, "sentinel", []string{}, "Redis sentinel hostname and port (e.g. argocd-redis-ha-announce-0:6379). ")
    97  	cmd.Flags().StringVar(&sentinelMaster, "sentinelmaster", "master", "Redis sentinel master group name.")
    98  	cmd.Flags().DurationVar(&defaultCacheExpiration, "default-cache-expiration", env.ParseDurationFromEnv("ARGOCD_DEFAULT_CACHE_EXPIRATION", 24*time.Hour, 0, math.MaxInt64), "Cache expiration default")
    99  	cmd.Flags().BoolVar(&redisUseTLS, "redis-use-tls", false, "Use TLS when connecting to Redis. ")
   100  	cmd.Flags().StringVar(&redisClientCertificate, "redis-client-certificate", "", "Path to Redis client certificate (e.g. /etc/certs/redis/client.crt).")
   101  	cmd.Flags().StringVar(&redisClientKey, "redis-client-key", "", "Path to Redis client key (e.g. /etc/certs/redis/client.crt).")
   102  	cmd.Flags().BoolVar(&insecureRedis, "redis-insecure-skip-tls-verify", false, "Skip Redis server certificate validation.")
   103  	cmd.Flags().StringVar(&redisCACertificate, "redis-ca-certificate", "", "Path to Redis server CA certificate (e.g. /etc/certs/redis/ca.crt). If not specified, system trusted CAs will be used for server certificate validation.")
   104  	cmd.Flags().StringVar(&compressionStr, CLIFlagRedisCompress, env.StringFromEnv("REDIS_COMPRESSION", string(RedisCompressionGZip)), "Enable compression for data sent to Redis with the required compression algorithm. (possible values: gzip, none)")
   105  	return func() (*Cache, error) {
   106  		var tlsConfig *tls.Config = nil
   107  		if redisUseTLS {
   108  			tlsConfig = &tls.Config{}
   109  			if redisClientCertificate != "" {
   110  				clientCert, err := tls.LoadX509KeyPair(redisClientCertificate, redisClientKey)
   111  				if err != nil {
   112  					return nil, err
   113  				}
   114  				tlsConfig.Certificates = []tls.Certificate{clientCert}
   115  			}
   116  			if insecureRedis {
   117  				tlsConfig.InsecureSkipVerify = true
   118  			} else if redisCACertificate != "" {
   119  				redisCA, err := certutil.ParseTLSCertificatesFromPath(redisCACertificate)
   120  				if err != nil {
   121  					return nil, err
   122  				}
   123  				tlsConfig.RootCAs = certutil.GetCertPoolFromPEMData(redisCA)
   124  			} else {
   125  				var err error
   126  				tlsConfig.RootCAs, err = x509.SystemCertPool()
   127  				if err != nil {
   128  					return nil, err
   129  				}
   130  			}
   131  		}
   132  		password := os.Getenv(envRedisPassword)
   133  		username := os.Getenv(envRedisUsername)
   134  		maxRetries := env.ParseNumFromEnv(envRedisRetryCount, defaultRedisRetryCount, 0, math.MaxInt32)
   135  		compression, err := CompressionTypeFromString(compressionStr)
   136  		if err != nil {
   137  			return nil, err
   138  		}
   139  		if len(sentinelAddresses) > 0 {
   140  			client := buildFailoverRedisClient(sentinelMaster, password, username, redisDB, maxRetries, tlsConfig, sentinelAddresses)
   141  			for i := range opts {
   142  				opts[i](client)
   143  			}
   144  			return NewCache(NewRedisCache(client, defaultCacheExpiration, compression)), nil
   145  		}
   146  		if redisAddress == "" {
   147  			redisAddress = common.DefaultRedisAddr
   148  		}
   149  
   150  		client := buildRedisClient(redisAddress, password, username, redisDB, maxRetries, tlsConfig)
   151  		for i := range opts {
   152  			opts[i](client)
   153  		}
   154  		return NewCache(NewRedisCache(client, defaultCacheExpiration, compression)), nil
   155  	}
   156  }
   157  
   158  // Cache provides strongly types methods to store and retrieve values from shared cache
   159  type Cache struct {
   160  	client CacheClient
   161  }
   162  
   163  func (c *Cache) GetClient() CacheClient {
   164  	return c.client
   165  }
   166  
   167  func (c *Cache) SetClient(client CacheClient) {
   168  	c.client = client
   169  }
   170  
   171  func (c *Cache) SetItem(key string, item interface{}, expiration time.Duration, delete bool) error {
   172  	key = fmt.Sprintf("%s|%s", key, common.CacheVersion)
   173  	if delete {
   174  		return c.client.Delete(key)
   175  	} else {
   176  		if item == nil {
   177  			return fmt.Errorf("cannot set item to nil for key %s", key)
   178  		}
   179  		return c.client.Set(&Item{Object: item, Key: key, Expiration: expiration})
   180  	}
   181  }
   182  
   183  func (c *Cache) GetItem(key string, item interface{}) error {
   184  	if item == nil {
   185  		return fmt.Errorf("cannot get item into a nil for key %s", key)
   186  	}
   187  	key = fmt.Sprintf("%s|%s", key, common.CacheVersion)
   188  	return c.client.Get(key, item)
   189  }
   190  
   191  func (c *Cache) OnUpdated(ctx context.Context, key string, callback func() error) error {
   192  	return c.client.OnUpdated(ctx, fmt.Sprintf("%s|%s", key, common.CacheVersion), callback)
   193  }
   194  
   195  func (c *Cache) NotifyUpdated(key string) error {
   196  	return c.client.NotifyUpdated(fmt.Sprintf("%s|%s", key, common.CacheVersion))
   197  }