go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/server/gaeemulation/rediscache.go (about)

     1  // Copyright 2020 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package gaeemulation
    16  
    17  import (
    18  	"context"
    19  	"crypto/sha1"
    20  	"encoding/hex"
    21  	"fmt"
    22  	"strings"
    23  	"time"
    24  
    25  	"github.com/gomodule/redigo/redis"
    26  	"go.opentelemetry.io/otel"
    27  	"go.opentelemetry.io/otel/codes"
    28  
    29  	"go.chromium.org/luci/common/errors"
    30  	"go.chromium.org/luci/common/logging"
    31  	"go.chromium.org/luci/gae/filter/dscache"
    32  )
    33  
     34  var tracer = otel.Tracer("go.chromium.org/luci/server/gaeemulation") // starts the span that redisCache.do opens around each cache operation
    35  
     36  const (
     37  	lockPrefix = 'L' // items that hold locks start with this byte
     38  	dataPrefix = 'D' // items that hold data start with this byte
     39  
     40  	// Length of a lock item written by this code: a prefix byte + nonce.
     41  	maxLockItemLen = 1 + dscache.NonceBytes
     42  )
    43  
     44  // dataPrefixBuf is the shared one-byte slice returned by Prefix() below, preallocated to avoid allocations there.
     45  var dataPrefixBuf = []byte{dataPrefix}
    46  
    47  // The script does "compare prefix and swap".
    48  //
    49  // Arguments:
    50  //
    51  //	KEYS[1]: the key to operate on.
    52  //	ARGV[1]: the old value to compare to (its first maxLockItemLen bytes).
    53  //	ARGV[2]: the new value to write.
    54  //	ARGV[3]: expiration time (in sec) of the new value.
    55  var casScript = strings.TrimSpace(fmt.Sprintf(`
    56  if redis.call("GETRANGE", KEYS[1], 0, %d) == ARGV[1] then
    57  	return redis.call("SET", KEYS[1], ARGV[2], "EX", ARGV[3])
    58  end
    59  `, maxLockItemLen))
    60  
    61  // casScriptSHA1 is SHA1 of `casScript`, to be used with EVALSHA to save on
    62  // a round trip to redis per CAS.
    63  var casScriptSHA1 string
    64  
    65  func init() {
    66  	dgst := sha1.Sum([]byte(casScript))
    67  	casScriptSHA1 = hex.EncodeToString(dgst[:])
    68  }
    69  
     70  // redisCache implements dscache.Cache via Redis, taking a connection from `pool` for each operation.
     71  type redisCache struct {
     72  	pool *redis.Pool // connections are obtained from it in do() and closed when the operation ends
     73  }
    74  
    75  func (c redisCache) do(ctx context.Context, op string, cb func(conn redis.Conn) error) (err error) {
    76  	ctx, span := tracer.Start(ctx, "go.chromium.org/luci/server/redisCache."+op)
    77  	defer func() {
    78  		if err != nil {
    79  			span.RecordError(err)
    80  			span.SetStatus(codes.Error, err.Error())
    81  		}
    82  		span.End()
    83  	}()
    84  
    85  	conn, err := c.pool.GetContext(ctx)
    86  	if err != nil {
    87  		return errors.Annotate(err, "dscache %s", op).Err()
    88  	}
    89  	defer conn.Close()
    90  
    91  	if err = cb(conn); err != nil {
    92  		return errors.Annotate(err, "dscache %s", op).Err()
    93  	}
    94  	return nil
    95  }
    96  
    97  func (c redisCache) PutLocks(ctx context.Context, keys []string, timeout time.Duration) error {
    98  	if len(keys) == 0 {
    99  		return nil
   100  	}
   101  	return c.do(ctx, "PutLocks", func(conn redis.Conn) error {
   102  		for _, key := range keys {
   103  			conn.Send("SET", key, []byte{lockPrefix}, "EX", int(timeout.Seconds()))
   104  		}
   105  		_, err := conn.Do("")
   106  		return err
   107  	})
   108  }
   109  
   110  func (c redisCache) DropLocks(ctx context.Context, keys []string) error {
   111  	if len(keys) == 0 {
   112  		return nil
   113  	}
   114  	return c.do(ctx, "DropLocks", func(conn redis.Conn) error {
   115  		for _, key := range keys {
   116  			conn.Send("DEL", key)
   117  		}
   118  		_, err := conn.Do("")
   119  		return err
   120  	})
   121  }
   122  
   123  func (c redisCache) TryLockAndFetch(ctx context.Context, keys []string, nonce []byte, timeout time.Duration) ([]dscache.CacheItem, error) {
   124  	if len(keys) == 0 {
   125  		return nil, nil
   126  	}
   127  
   128  	// Prepopulate the response with nil items which mean "cache miss". It is
   129  	// always safe to return them, the dscache will fallback to using datastore
   130  	// (without touching the cache in the end).
   131  	items := make([]dscache.CacheItem, len(keys))
   132  
   133  	err := c.do(ctx, "TryLockAndFetch", func(conn redis.Conn) (err error) {
   134  		// Send a pipeline of SET NX+GET pairs.
   135  		prefixedNonce := append([]byte{lockPrefix}, nonce...)
   136  		for _, key := range keys {
   137  			if key == "" {
   138  				continue
   139  			}
   140  			conn.Send("SET", key, prefixedNonce, "NX", "EX", int(timeout.Seconds()))
   141  			conn.Send("GET", key)
   142  		}
   143  		conn.Flush()
   144  
   145  		// Parse replies.
   146  		for i, key := range keys {
   147  			if key == "" {
   148  				continue
   149  			}
   150  			conn.Receive() // skip the result of "SET", we want "GET"
   151  			if body, err := redis.Bytes(conn.Receive()); err == nil {
   152  				items[i] = &cacheItem{key: key, body: body}
   153  			}
   154  			if conn.Err() != nil {
   155  				return conn.Err() // the connection is dropped, can't fetch the rest
   156  			}
   157  		}
   158  		return nil
   159  	})
   160  
   161  	return items, err
   162  }
   163  
   164  func (c redisCache) CompareAndSwap(ctx context.Context, items []dscache.CacheItem) error {
   165  	if len(items) == 0 {
   166  		return nil
   167  	}
   168  
   169  	return c.do(ctx, "CompareAndSwap", func(conn redis.Conn) error {
   170  		toSwap := make([]*cacheItem, len(items))
   171  		for i, item := range items {
   172  			item := item.(*cacheItem)
   173  			if item.lock == nil {
   174  				panic("dscache violated Cache contract: can CAS only promoted items")
   175  			}
   176  			toSwap[i] = item
   177  		}
   178  
   179  		for {
   180  			// Pipeline all CAS operations at once.
   181  			for _, item := range toSwap {
   182  				_ = conn.Send("EVALSHA",
   183  					casScriptSHA1, // the script to execute
   184  					1,             // number of key-typed arguments (see casScript)
   185  					item.key,      // the key to operate on
   186  					item.lock,     // will be compared to what's in the cache right now
   187  					item.body,     // the new value if comparison succeeds
   188  					int(item.exp.Seconds()),
   189  				)
   190  			}
   191  
   192  			// Flush and read results. Here `err` is a connection-level error and
   193  			// `batchReply` is secretly an array of EVALSHA replies (some of which can
   194  			// be redis.Error).
   195  			batchReply, err := conn.Do("")
   196  			if err != nil {
   197  				return err
   198  			}
   199  
   200  			replies := batchReply.([]any)
   201  			if len(replies) != len(toSwap) {
   202  				panic(fmt.Sprintf("Redis protocol violation: %d != %d", len(replies), len(toSwap)))
   203  			}
   204  
   205  			// If we get a NOSCRIPT error, need to load the CAS script and redo
   206  			// operations that failed. Any other error is non-recoverable.
   207  			toRetry := toSwap[:0]
   208  			for i, rep := range replies {
   209  				if err, isErr := rep.(redis.Error); isErr {
   210  					if strings.HasPrefix(err.Error(), "NOSCRIPT ") {
   211  						toRetry = append(toRetry, toSwap[i])
   212  					} else {
   213  						return err
   214  					}
   215  				}
   216  			}
   217  			if len(toRetry) == 0 {
   218  				return nil
   219  			}
   220  			toSwap = toRetry
   221  
   222  			// Redis doesn't know about the script yet. Load it and retry EVALSHAs.
   223  			// This should happen very rarely (in theory only after Redis server
   224  			// restarts or full flushes).
   225  			logging.Warningf(ctx, "Loading the CAS script into Redis")
   226  			if _, err = conn.Do("SCRIPT", "LOAD", casScript); err != nil {
   227  				return err
   228  			}
   229  		}
   230  	})
   231  }
   232  
   233  ////////////////////////////////////////////////////////////////////////////////
   234  
   235  type cacheItem struct {
   236  	key  string
   237  	body []byte
   238  
   239  	lock []byte        // set to the previous value of `body` after the promotion
   240  	exp  time.Duration // set after the promotion
   241  }
   242  
   243  func (ci *cacheItem) Key() string {
   244  	return ci.key
   245  }
   246  
   247  func (ci *cacheItem) Nonce() []byte {
   248  	if len(ci.body) > 0 && ci.body[0] == lockPrefix {
   249  		return ci.body[1:]
   250  	}
   251  	return nil
   252  }
   253  
   254  func (ci *cacheItem) Data() []byte {
   255  	if len(ci.body) > 0 && ci.body[0] == dataPrefix {
   256  		return ci.body[1:]
   257  	}
   258  	return nil
   259  }
   260  
   261  func (ci *cacheItem) Prefix() []byte {
   262  	return dataPrefixBuf
   263  }
   264  
   265  func (ci *cacheItem) PromoteToData(data []byte, exp time.Duration) {
   266  	if len(data) == 0 || data[0] != dataPrefix {
   267  		panic("dscache violated CacheItem contract: data is not prefixed by Prefix()")
   268  	}
   269  	ci.promote(data, exp)
   270  }
   271  
   272  func (ci *cacheItem) PromoteToIndefiniteLock() {
   273  	ci.promote([]byte{lockPrefix}, time.Hour*24*30)
   274  }
   275  
   276  func (ci *cacheItem) promote(body []byte, exp time.Duration) {
   277  	if ci.lock != nil {
   278  		panic("already promoted")
   279  	}
   280  	if len(ci.body) == 0 || ci.body[0] != lockPrefix {
   281  		panic("not a lock item")
   282  	}
   283  	ci.lock = ci.body
   284  	// Note: this should not normally happen, but may happen if some items were
   285  	// written with different value of dscache.NonceBytes constant. We need to
   286  	// trim it for the casScript that compares only up to maxLockItemLen bytes.
   287  	if len(ci.lock) > maxLockItemLen {
   288  		ci.lock = ci.lock[:maxLockItemLen]
   289  	}
   290  	ci.body = body
   291  	ci.exp = exp
   292  }