github.com/argoproj/argo-cd/v3@v3.2.1/util/cache/cache.go (about)

     1  package cache
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"crypto/x509"
     7  	"errors"
     8  	"fmt"
     9  	"math"
    10  	"os"
    11  	"strings"
    12  	"time"
    13  
    14  	"github.com/redis/go-redis/v9"
    15  	log "github.com/sirupsen/logrus"
    16  	"github.com/spf13/cobra"
    17  
    18  	"github.com/argoproj/argo-cd/v3/common"
    19  	certutil "github.com/argoproj/argo-cd/v3/util/cert"
    20  	"github.com/argoproj/argo-cd/v3/util/env"
    21  )
    22  
    23  const (
    24  	// envRedisPassword is an env variable name which stores redis password
    25  	envRedisPassword = "REDIS_PASSWORD"
    26  	// envRedisUsername is an env variable name which stores redis username (for acl setup)
    27  	envRedisUsername = "REDIS_USERNAME"
    28  	// envRedisRetryCount is an env variable name which stores redis retry count
    29  	envRedisRetryCount = "REDIS_RETRY_COUNT"
    30  	// defaultRedisRetryCount holds default number of retries
    31  	defaultRedisRetryCount = 3
    32  	// envRedisSentinelPassword is an env variable name which stores redis sentinel password
    33  	envRedisSentinelPassword = "REDIS_SENTINEL_PASSWORD"
    34  	// envRedisSentinelUsername is an env variable name which stores redis sentinel username
    35  	envRedisSentinelUsername = "REDIS_SENTINEL_USERNAME"
    36  )
    37  
    38  const (
    39  	// CLIFlagRedisCompress is a cli flag name to define the redis compression setting for data sent to redis
    40  	CLIFlagRedisCompress = "redis-compress"
    41  )
    42  
    43  func NewCache(client CacheClient) *Cache {
    44  	return &Cache{client}
    45  }
    46  
    47  func buildRedisClient(redisAddress, password, username string, redisDB, maxRetries int, tlsConfig *tls.Config) *redis.Client {
    48  	opts := &redis.Options{
    49  		Addr:       redisAddress,
    50  		Password:   password,
    51  		DB:         redisDB,
    52  		MaxRetries: maxRetries,
    53  		TLSConfig:  tlsConfig,
    54  		Username:   username,
    55  	}
    56  
    57  	client := redis.NewClient(opts)
    58  
    59  	client.AddHook(redis.Hook(NewArgoRedisHook(func() {
    60  		*client = *buildRedisClient(redisAddress, password, username, redisDB, maxRetries, tlsConfig)
    61  	})))
    62  
    63  	return client
    64  }
    65  
    66  func buildFailoverRedisClient(sentinelMaster, sentinelUsername, sentinelPassword, password, username string, redisDB, maxRetries int, tlsConfig *tls.Config, sentinelAddresses []string) *redis.Client {
    67  	opts := &redis.FailoverOptions{
    68  		MasterName:       sentinelMaster,
    69  		SentinelAddrs:    sentinelAddresses,
    70  		DB:               redisDB,
    71  		Password:         password,
    72  		MaxRetries:       maxRetries,
    73  		TLSConfig:        tlsConfig,
    74  		Username:         username,
    75  		SentinelUsername: sentinelUsername,
    76  		SentinelPassword: sentinelPassword,
    77  	}
    78  
    79  	client := redis.NewFailoverClient(opts)
    80  
    81  	client.AddHook(redis.Hook(NewArgoRedisHook(func() {
    82  		*client = *buildFailoverRedisClient(sentinelMaster, sentinelUsername, sentinelPassword, password, username, redisDB, maxRetries, tlsConfig, sentinelAddresses)
    83  	})))
    84  
    85  	return client
    86  }
    87  
// Options customizes AddCacheFlagsToCmd. When several Options values are
// passed, mergeOptions combines them: for each field, the last non-empty
// value wins.
type Options struct {
	// FlagPrefix is prepended to every cache flag name registered by
	// AddCacheFlagsToCmd. Its upper-cased, dash-to-underscore form
	// (getEnvPrefix) prefixes the related environment variable names.
	FlagPrefix      string
	// OnClientCreated, if non-nil, is called with the Redis client built by
	// the constructor that AddCacheFlagsToCmd returns. Clients later rebuilt
	// by the reconnect hook are not reported through this callback.
	OnClientCreated func(client *redis.Client)
}
    92  
    93  func (o *Options) callOnClientCreated(client *redis.Client) {
    94  	if o.OnClientCreated != nil {
    95  		o.OnClientCreated(client)
    96  	}
    97  }
    98  
    99  func (o *Options) getEnvPrefix() string {
   100  	return strings.ReplaceAll(strings.ToUpper(o.FlagPrefix), "-", "_")
   101  }
   102  
   103  func mergeOptions(opts ...Options) Options {
   104  	var result Options
   105  	for _, o := range opts {
   106  		if o.FlagPrefix != "" {
   107  			result.FlagPrefix = o.FlagPrefix
   108  		}
   109  		if o.OnClientCreated != nil {
   110  			result.OnClientCreated = o.OnClientCreated
   111  		}
   112  	}
   113  	return result
   114  }
   115  
   116  func getFlagVal[T any](cmd *cobra.Command, o Options, name string, getVal func(name string) (T, error)) func() T {
   117  	return func() T {
   118  		var res T
   119  		var err error
   120  		if o.FlagPrefix != "" && cmd.Flags().Changed(o.FlagPrefix+name) {
   121  			res, err = getVal(o.FlagPrefix + name)
   122  		} else {
   123  			res, err = getVal(name)
   124  		}
   125  		if err != nil {
   126  			panic(err)
   127  		}
   128  		return res
   129  	}
   130  }
   131  
   132  // AddCacheFlagsToCmd adds flags which control caching to the specified command
   133  func AddCacheFlagsToCmd(cmd *cobra.Command, opts ...Options) func() (*Cache, error) {
   134  	redisAddress := ""
   135  	sentinelAddresses := make([]string, 0)
   136  	sentinelMaster := ""
   137  	redisDB := 0
   138  	redisCACertificate := ""
   139  	redisClientCertificate := ""
   140  	redisClientKey := ""
   141  	redisUseTLS := false
   142  	insecureRedis := false
   143  	compressionStr := ""
   144  	opt := mergeOptions(opts...)
   145  	var defaultCacheExpiration time.Duration
   146  
   147  	cmd.Flags().StringVar(&redisAddress, opt.FlagPrefix+"redis", env.StringFromEnv(opt.getEnvPrefix()+"REDIS_SERVER", ""), "Redis server hostname and port (e.g. argocd-redis:6379). ")
   148  	redisAddressSrc := getFlagVal(cmd, opt, "redis", cmd.Flags().GetString)
   149  	cmd.Flags().IntVar(&redisDB, opt.FlagPrefix+"redisdb", env.ParseNumFromEnv(opt.getEnvPrefix()+"REDISDB", 0, 0, math.MaxInt32), "Redis database.")
   150  	redisDBSrc := getFlagVal(cmd, opt, "redisdb", cmd.Flags().GetInt)
   151  	cmd.Flags().StringArrayVar(&sentinelAddresses, opt.FlagPrefix+"sentinel", []string{}, "Redis sentinel hostname and port (e.g. argocd-redis-ha-announce-0:6379). ")
   152  	sentinelAddressesSrc := getFlagVal(cmd, opt, "sentinel", cmd.Flags().GetStringArray)
   153  	cmd.Flags().StringVar(&sentinelMaster, opt.FlagPrefix+"sentinelmaster", "master", "Redis sentinel master group name.")
   154  	sentinelMasterSrc := getFlagVal(cmd, opt, "sentinelmaster", cmd.Flags().GetString)
   155  	cmd.Flags().DurationVar(&defaultCacheExpiration, opt.FlagPrefix+"default-cache-expiration", env.ParseDurationFromEnv("ARGOCD_DEFAULT_CACHE_EXPIRATION", 24*time.Hour, 0, math.MaxInt64), "Cache expiration default")
   156  	defaultCacheExpirationSrc := getFlagVal(cmd, opt, "default-cache-expiration", cmd.Flags().GetDuration)
   157  	cmd.Flags().BoolVar(&redisUseTLS, opt.FlagPrefix+"redis-use-tls", false, "Use TLS when connecting to Redis. ")
   158  	redisUseTLSSrc := getFlagVal(cmd, opt, "redis-use-tls", cmd.Flags().GetBool)
   159  	cmd.Flags().StringVar(&redisClientCertificate, opt.FlagPrefix+"redis-client-certificate", "", "Path to Redis client certificate (e.g. /etc/certs/redis/client.crt).")
   160  	redisClientCertificateSrc := getFlagVal(cmd, opt, "redis-client-certificate", cmd.Flags().GetString)
   161  	cmd.Flags().StringVar(&redisClientKey, opt.FlagPrefix+"redis-client-key", "", "Path to Redis client key (e.g. /etc/certs/redis/client.crt).")
   162  	redisClientKeySrc := getFlagVal(cmd, opt, "redis-client-key", cmd.Flags().GetString)
   163  	cmd.Flags().BoolVar(&insecureRedis, opt.FlagPrefix+"redis-insecure-skip-tls-verify", false, "Skip Redis server certificate validation.")
   164  	insecureRedisSrc := getFlagVal(cmd, opt, "redis-insecure-skip-tls-verify", cmd.Flags().GetBool)
   165  	cmd.Flags().StringVar(&redisCACertificate, opt.FlagPrefix+"redis-ca-certificate", "", "Path to Redis server CA certificate (e.g. /etc/certs/redis/ca.crt). If not specified, system trusted CAs will be used for server certificate validation.")
   166  	redisCACertificateSrc := getFlagVal(cmd, opt, "redis-ca-certificate", cmd.Flags().GetString)
   167  	cmd.Flags().StringVar(&compressionStr, opt.FlagPrefix+CLIFlagRedisCompress, env.StringFromEnv(opt.getEnvPrefix()+"REDIS_COMPRESSION", string(RedisCompressionGZip)), "Enable compression for data sent to Redis with the required compression algorithm. (possible values: gzip, none)")
   168  	compressionStrSrc := getFlagVal(cmd, opt, CLIFlagRedisCompress, cmd.Flags().GetString)
   169  	return func() (*Cache, error) {
   170  		redisAddress := redisAddressSrc()
   171  		redisDB := redisDBSrc()
   172  		sentinelAddresses := sentinelAddressesSrc()
   173  		sentinelMaster := sentinelMasterSrc()
   174  		defaultCacheExpiration := defaultCacheExpirationSrc()
   175  		redisUseTLS := redisUseTLSSrc()
   176  		redisClientCertificate := redisClientCertificateSrc()
   177  		redisClientKey := redisClientKeySrc()
   178  		insecureRedis := insecureRedisSrc()
   179  		redisCACertificate := redisCACertificateSrc()
   180  		compressionStr := compressionStrSrc()
   181  
   182  		var tlsConfig *tls.Config
   183  		if redisUseTLS {
   184  			tlsConfig = &tls.Config{}
   185  			if redisClientCertificate != "" {
   186  				clientCert, err := tls.LoadX509KeyPair(redisClientCertificate, redisClientKey)
   187  				if err != nil {
   188  					return nil, err
   189  				}
   190  				tlsConfig.Certificates = []tls.Certificate{clientCert}
   191  			}
   192  			switch {
   193  			case insecureRedis:
   194  				tlsConfig.InsecureSkipVerify = true
   195  			case redisCACertificate != "":
   196  				redisCA, err := certutil.ParseTLSCertificatesFromPath(redisCACertificate)
   197  				if err != nil {
   198  					return nil, err
   199  				}
   200  				tlsConfig.RootCAs = certutil.GetCertPoolFromPEMData(redisCA)
   201  			default:
   202  				var err error
   203  				tlsConfig.RootCAs, err = x509.SystemCertPool()
   204  				if err != nil {
   205  					return nil, err
   206  				}
   207  			}
   208  		}
   209  		password := os.Getenv(envRedisPassword)
   210  		username := os.Getenv(envRedisUsername)
   211  		sentinelUsername := os.Getenv(envRedisSentinelUsername)
   212  		sentinelPassword := os.Getenv(envRedisSentinelPassword)
   213  		if opt.FlagPrefix != "" {
   214  			if val := os.Getenv(opt.getEnvPrefix() + envRedisUsername); val != "" {
   215  				username = val
   216  			}
   217  			if val := os.Getenv(opt.getEnvPrefix() + envRedisPassword); val != "" {
   218  				password = val
   219  			}
   220  			if val := os.Getenv(opt.getEnvPrefix() + envRedisSentinelUsername); val != "" {
   221  				sentinelUsername = val
   222  			}
   223  			if val := os.Getenv(opt.getEnvPrefix() + envRedisSentinelPassword); val != "" {
   224  				sentinelPassword = val
   225  			}
   226  		}
   227  
   228  		maxRetries := env.ParseNumFromEnv(envRedisRetryCount, defaultRedisRetryCount, 0, math.MaxInt32)
   229  		compression, err := CompressionTypeFromString(compressionStr)
   230  		if err != nil {
   231  			return nil, err
   232  		}
   233  		if len(sentinelAddresses) > 0 {
   234  			client := buildFailoverRedisClient(sentinelMaster, sentinelUsername, sentinelPassword, password, username, redisDB, maxRetries, tlsConfig, sentinelAddresses)
   235  			opt.callOnClientCreated(client)
   236  			return NewCache(NewRedisCache(client, defaultCacheExpiration, compression)), nil
   237  		}
   238  		if redisAddress == "" {
   239  			redisAddress = common.DefaultRedisAddr
   240  		}
   241  
   242  		client := buildRedisClient(redisAddress, password, username, redisDB, maxRetries, tlsConfig)
   243  		opt.callOnClientCreated(client)
   244  		return NewCache(NewRedisCache(client, defaultCacheExpiration, compression)), nil
   245  	}
   246  }
   247  
// Cache provides strongly typed methods to store and retrieve values from a
// shared cache. Keys passed to SetItem, GetItem, RenameItem, OnUpdated and
// NotifyUpdated get the "|"+common.CacheVersion suffix before they reach the
// underlying CacheClient.
type Cache struct {
	client CacheClient
}
   252  
   253  func (c *Cache) GetClient() CacheClient {
   254  	return c.client
   255  }
   256  
// SetClient replaces the CacheClient that c uses for all subsequent
// operations.
func (c *Cache) SetClient(client CacheClient) {
	c.client = client
}
   260  
   261  func (c *Cache) RenameItem(oldKey string, newKey string, expiration time.Duration) error {
   262  	return c.client.Rename(fmt.Sprintf("%s|%s", oldKey, common.CacheVersion), fmt.Sprintf("%s|%s", newKey, common.CacheVersion), expiration)
   263  }
   264  
   265  func (c *Cache) generateFullKey(key string) string {
   266  	if key == "" {
   267  		log.Debug("Cache key is empty, this will result in key collisions if there is more than one empty key")
   268  	}
   269  	return fmt.Sprintf("%s|%s", key, common.CacheVersion)
   270  }
   271  
   272  // Sets or deletes an item in cache
   273  func (c *Cache) SetItem(key string, item any, opts *CacheActionOpts) error {
   274  	if item == nil {
   275  		return errors.New("cannot set nil item in cache")
   276  	}
   277  	if opts == nil {
   278  		opts = &CacheActionOpts{}
   279  	}
   280  	fullKey := c.generateFullKey(key)
   281  	client := c.GetClient()
   282  	if opts.Delete {
   283  		return client.Delete(fullKey)
   284  	}
   285  	return client.Set(&Item{Key: fullKey, Object: item, CacheActionOpts: *opts})
   286  }
   287  
   288  func (c *Cache) GetItem(key string, item any) error {
   289  	key = c.generateFullKey(key)
   290  	if item == nil {
   291  		return fmt.Errorf("cannot get item into a nil for key %s", key)
   292  	}
   293  	client := c.GetClient()
   294  	return client.Get(key, item)
   295  }
   296  
   297  func (c *Cache) OnUpdated(ctx context.Context, key string, callback func() error) error {
   298  	return c.client.OnUpdated(ctx, c.generateFullKey(key), callback)
   299  }
   300  
   301  func (c *Cache) NotifyUpdated(key string) error {
   302  	return c.client.NotifyUpdated(c.generateFullKey(key))
   303  }