github.com/m-lab/locate@v0.17.6/memorystore/client.go (about) 1 package memorystore 2 3 import ( 4 "encoding/json" 5 "time" 6 7 "github.com/gomodule/redigo/redis" 8 "github.com/m-lab/locate/metrics" 9 "github.com/m-lab/locate/static" 10 ) 11 12 const ( 13 // This is a Lua script that will be interpreted by the Redis server. 14 // The key/argument parameters (e.g., KEYS[1]) are passed to the script 15 // when it is invoked in the Put method (e.g., redis.Args{}.Add(...)). 16 // The command used to interpret the script in Redis is the EVAL command. 17 // Its documentation can be found under https://redis.io/commands/eval/. 18 script = `if redis.call('HEXISTS', KEYS[1], ARGV[1]) == 1 19 then return redis.call('HSET', KEYS[1], ARGV[2], ARGV[3]) 20 else error('key not found') 21 end` 22 ) 23 24 // PutOptions defines the parameters that can be used for PUT operations. 25 type PutOptions struct { 26 FieldMustExist string // Specifies a field that must already exist in the entry. 27 WithExpire bool // Specifies whether an expiration should be added to the entry. 28 } 29 30 type client[V any] struct { 31 pool *redis.Pool 32 } 33 34 // NewClient returns a new MemorystoreClient implementation 35 // that reads and writes data in Redis. 36 func NewClient[V any](pool *redis.Pool) *client[V] { 37 return &client[V]{pool} 38 } 39 40 // Put sets a Redis Hash using the `HSET key field value` command. 41 // If the `opts.WithExpire` option is true, it also (re)sets the key's timeout. 42 func (c *client[V]) Put(key string, field string, value redis.Scanner, opts *PutOptions) error { 43 t := time.Now() 44 conn := c.pool.Get() 45 defer conn.Close() 46 47 b, err := json.Marshal(value) 48 if err != nil { 49 metrics.LocateMemorystoreRequestDuration.WithLabelValues("put", field, "marshal error").Observe(time.Since(t).Seconds()) 50 return err 51 } 52 53 if opts.FieldMustExist != "" { 54 args := redis.Args{}.Add(script).Add(1).Add(key).Add(opts.FieldMustExist).Add(field).AddFlat(string(b)) 55 _, err = conn.Do("EVAL", args...) 56 if err != nil { 57 metrics.LocateMemorystoreRequestDuration.WithLabelValues("put", field, "EVAL error").Observe(time.Since(t).Seconds()) 58 return err 59 } 60 } else { 61 args := redis.Args{}.Add(key).Add(field).AddFlat(string(b)) 62 _, err = conn.Do("HSET", args...) 63 if err != nil { 64 metrics.LocateMemorystoreRequestDuration.WithLabelValues("put", field, "HSET error").Observe(time.Since(t).Seconds()) 65 return err 66 } 67 } 68 69 if !opts.WithExpire { 70 metrics.LocateMemorystoreRequestDuration.WithLabelValues("put", field, "OK").Observe(time.Since(t).Seconds()) 71 return nil 72 } 73 74 _, err = conn.Do("EXPIRE", key, static.RedisKeyExpirySecs) 75 if err != nil { 76 metrics.LocateMemorystoreRequestDuration.WithLabelValues("put", field, "EXPIRE error").Observe(time.Since(t).Seconds()) 77 return err 78 } 79 80 metrics.LocateMemorystoreRequestDuration.WithLabelValues("put", field+" with expiration", "OK").Observe(time.Since(t).Seconds()) 81 return nil 82 } 83 84 // Del removes a key from Redis using the `DEL key` command. 85 func (c *client[V]) Del(key string) error { 86 t := time.Now() 87 conn := c.pool.Get() 88 defer conn.Close() 89 90 _, err := conn.Do("DEL", key) 91 if err != nil { 92 metrics.LocateMemorystoreRequestDuration.WithLabelValues("del", "", "DEL error").Observe(time.Since(t).Seconds()) 93 return err 94 } 95 96 metrics.LocateMemorystoreRequestDuration.WithLabelValues("del", "", "OK").Observe(time.Since(t).Seconds()) 97 return nil 98 } 99 100 // GetAll uses the SCAN command to iterate over all the entries in Redis 101 // and returns a mapping of all the keys to their values. 102 // It implements an "all or nothing" approach in which it will only 103 // return the entries if all of them are scanned successfully. 104 // Otherwise, it will return an error. 105 func (c *client[V]) GetAll() (map[string]V, error) { 106 t := time.Now() 107 conn := c.pool.Get() 108 defer conn.Close() 109 110 values := make(map[string]V) 111 iter := 0 112 113 for { 114 keys, err := redis.Values(conn.Do("SCAN", iter)) 115 if err != nil { 116 metrics.LocateMemorystoreRequestDuration.WithLabelValues("get", "all", "SCAN error").Observe(time.Since(t).Seconds()) 117 return nil, err 118 } 119 120 var temp []string 121 keys, err = redis.Scan(keys, &iter, &temp) 122 if err != nil { 123 metrics.LocateMemorystoreRequestDuration.WithLabelValues("get", "all", "SCAN copy error").Observe(time.Since(t).Seconds()) 124 return nil, err 125 } 126 127 for _, k := range temp { 128 v, err := c.get(k, conn) 129 if err != nil { 130 metrics.LocateMemorystoreRequestDuration.WithLabelValues("get", "all", "HGETALL error").Observe(time.Since(t).Seconds()) 131 return nil, err 132 } 133 values[k] = v 134 } 135 136 if iter == 0 { 137 metrics.LocateMemorystoreRequestDuration.WithLabelValues("get", "all", "OK").Observe(time.Since(t).Seconds()) 138 return values, nil 139 } 140 } 141 } 142 143 func (c *client[V]) get(key string, conn redis.Conn) (V, error) { 144 v := new(V) 145 val, err := redis.Values(conn.Do("HGETALL", key)) 146 if err != nil { 147 return *v, err 148 } 149 150 err = redis.ScanStruct(val, v) 151 if err != nil { 152 return *v, err 153 } 154 155 return *v, nil 156 }