go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/resultdb/internal/invocations/graph/graph.go (about)

     1  // Copyright 2022 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package graph
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"time"
    21  
    22  	"cloud.google.com/go/spanner"
    23  	"github.com/gomodule/redigo/redis"
    24  	"go.opentelemetry.io/otel/attribute"
    25  	"golang.org/x/sync/errgroup"
    26  	"google.golang.org/protobuf/proto"
    27  
    28  	"go.chromium.org/luci/common/errors"
    29  	"go.chromium.org/luci/common/logging"
    30  	"go.chromium.org/luci/resultdb/internal/invocations"
    31  	"go.chromium.org/luci/resultdb/internal/spanutil"
    32  	"go.chromium.org/luci/resultdb/internal/tracing"
    33  	pb "go.chromium.org/luci/resultdb/proto/v1"
    34  	"go.chromium.org/luci/server/redisconn"
    35  	"go.chromium.org/luci/server/span"
    36  )
    37  
    38  // MaxNodes is the maximum number of invocation nodes that ResultDB
    39  // can operate on at a time.
    40  const MaxNodes = 20000
    41  
    42  // reachCacheExpiration is expiration duration of ReachCache.
    43  // It is more important to have *some* expiration; the value itself matters less
    44  // because Redis evicts LRU keys only with *some* expiration set,
    45  // see volatile-lru policy: https://redis.io/topics/lru-cache
    46  const reachCacheExpiration = 30 * 24 * time.Hour // 30 days
    47  
    48  // TooManyTag set in an error indicates that too many invocations
    49  // matched a condition.
    50  var TooManyTag = errors.BoolTag{
    51  	Key: errors.NewTagKey("too many matching invocations matched the condition"),
    52  }
    53  
    54  // Reachable returns all invocations reachable from roots along the inclusion
    55  // edges. May return an appstatus-annotated error.
    56  func Reachable(ctx context.Context, roots invocations.IDSet) (ReachableInvocations, error) {
    57  	invs, err := reachable(ctx, roots, false)
    58  	if err != nil {
    59  		return ReachableInvocations{}, err
    60  	}
    61  	return invs, nil
    62  }
    63  
    64  // ReachableSkipRootCache is similar to BatchedReachable, but it ignores cache
    65  // for the roots.
    66  //
    67  // Useful to keep cache-hit stats high in cases where the roots are known not to
    68  // have cache.
    69  func ReachableSkipRootCache(ctx context.Context, roots invocations.IDSet) (ReachableInvocations, error) {
    70  	invs, err := reachable(ctx, roots, false)
    71  	if err != nil {
    72  		return ReachableInvocations{}, err
    73  	}
    74  	return invs, nil
    75  }
    76  
    77  func reachable(ctx context.Context, roots invocations.IDSet, useRootCache bool) (reachable ReachableInvocations, err error) {
    78  	reachable = NewReachableInvocations()
    79  	uncachedRoots := invocations.NewIDSet()
    80  	if useRootCache {
    81  		for id := range roots {
    82  			// First check the cache.
    83  			switch reachables, err := reachCache(id).Read(ctx); {
    84  			case err == redisconn.ErrNotConfigured || err == ErrUnknownReach:
    85  				// Ignore this error.
    86  				uncachedRoots.Add(id)
    87  			case err != nil:
    88  				logging.Warningf(ctx, "ReachCache: failed to read %s: %s", id, err)
    89  				uncachedRoots.Add(id)
    90  			default:
    91  				// Cache hit. Copy the results to `reachable`.
    92  				reachable.Union(reachables)
    93  			}
    94  		}
    95  	} else {
    96  		uncachedRoots.Union(roots)
    97  	}
    98  
    99  	if len(uncachedRoots) == 0 {
   100  		return reachable, nil
   101  	}
   102  
   103  	uncachedReachable, err := reachableUncached(ctx, uncachedRoots)
   104  	if err != nil {
   105  		return ReachableInvocations{}, err
   106  	}
   107  	reachable.Union(uncachedReachable)
   108  
   109  	// If we queried for one root and we had a cache miss, try to insert the
   110  	// reachable invocations, so that the cache will hopefully be populated
   111  	// next time.
   112  	if len(uncachedRoots) == 1 {
   113  		var root invocations.ID
   114  		for id := range uncachedRoots {
   115  			root = id
   116  		}
   117  		state, err := invocations.ReadState(ctx, root)
   118  		if err != nil {
   119  			logging.Warningf(ctx, "reachable: failed to read root invocation %s: %s", root, err)
   120  		} else if state == pb.Invocation_FINALIZED {
   121  			// Only populate the cache if the invocation exists and is
   122  			// finalized.
   123  			reachCache(root).TryWrite(ctx, uncachedReachable)
   124  		}
   125  	}
   126  
   127  	logging.Debugf(ctx, "%d invocations are reachable from %s", len(reachable.Invocations), roots.Names())
   128  	return reachable, nil
   129  }
   130  
   131  // reachableUncached queries the Spanner database for the reachability graph if the data is not in the reach cache.
   132  func reachableUncached(ctx context.Context, roots invocations.IDSet) (ri ReachableInvocations, err error) {
   133  	ctx, ts := tracing.Start(ctx, "resultdb.graph.reachable")
   134  	defer func() { tracing.End(ts, err) }()
   135  
   136  	reachableInvocations := invocations.NewIDSet()
   137  	reachableInvocations.Union(roots)
   138  
   139  	// Stores a mapping from reachable invocations to the invocation
   140  	// they were included by. Roots are not captured.
   141  	// If the same invocation is included by two or more invocations,
   142  	// only one of them is recorded as the parent. The exact
   143  	// parent invocation selected (if multiple are possible) is not
   144  	// defined, but it is guaranteed that following parents will
   145  	// eventually lead to a root (i.e. there are no cycles in the
   146  	// parent graph).
   147  	reachableInvocationToParent := make(map[invocations.ID]invocations.ID)
   148  
   149  	// Find all reachable invocations traversing the graph one level at a time.
   150  	nextLevel := invocations.NewIDSet()
   151  	nextLevel.Union(roots)
   152  	for len(nextLevel) > 0 {
   153  		includedInvs, err := queryIncludedInvocations(ctx, nextLevel)
   154  		if err != nil {
   155  			return ReachableInvocations{}, err
   156  		}
   157  
   158  		nextLevel = invocations.NewIDSet()
   159  		for inv, invParent := range includedInvs {
   160  			// Avoid duplicate lookups and cycles.
   161  			if _, ok := reachableInvocations[inv]; ok {
   162  				continue
   163  			}
   164  
   165  			nextLevel.Add(inv)
   166  			reachableInvocations.Add(inv)
   167  			reachableInvocationToParent[inv] = invParent
   168  		}
   169  	}
   170  
   171  	var withTestResults invocations.IDSet
   172  	var withExonerations invocations.IDSet
   173  	var invDetails map[invocations.ID]invocationDetails
   174  
   175  	eg, ctx := errgroup.WithContext(ctx)
   176  	eg.Go(func() error {
   177  		var err error
   178  		withTestResults, err = queryInvocations(ctx, `SELECT DISTINCT tr.InvocationID FROM UNNEST(@invocations) inv JOIN TestResults tr on tr.InvocationId = inv`, reachableInvocations)
   179  		if err != nil {
   180  			return errors.Annotate(err, "querying invocations with test results").Err()
   181  		}
   182  		return nil
   183  	})
   184  	eg.Go(func() error {
   185  		var err error
   186  		withExonerations, err = queryInvocations(ctx, `SELECT DISTINCT te.InvocationID FROM UNNEST(@invocations) inv JOIN TestExonerations te on te.InvocationId = inv`, reachableInvocations)
   187  		if err != nil {
   188  			return errors.Annotate(err, "querying invocations with test exonerations").Err()
   189  		}
   190  		return nil
   191  	})
   192  	eg.Go(func() error {
   193  		var err error
   194  		invDetails, err = queryInvocationDetails(ctx, reachableInvocations)
   195  		if err != nil {
   196  			return errors.Annotate(err, "querying realms of reachable invocations").Err()
   197  		}
   198  		return nil
   199  	})
   200  	if err := eg.Wait(); err != nil {
   201  		return ReachableInvocations{}, err
   202  	}
   203  
   204  	// Limit the returned reachable invocations to those that exist in the
   205  	// Invocations table; they will have a realm.
   206  	invocations := make(map[invocations.ID]ReachableInvocation, len(reachableInvocations))
   207  	distinctSources := make(map[SourceHash]*pb.Sources)
   208  	for id, details := range invDetails {
   209  		inv := ReachableInvocation{
   210  			HasTestResults:      withTestResults.Has(id),
   211  			HasTestExonerations: withExonerations.Has(id),
   212  			Realm:               details.Realm,
   213  		}
   214  
   215  		sources := resolveSources(id, reachableInvocationToParent, invDetails)
   216  		if sources != nil {
   217  			sourceHash := HashSources(sources)
   218  			distinctSources[sourceHash] = sources
   219  			inv.SourceHash = sourceHash
   220  		}
   221  		invocations[id] = inv
   222  	}
   223  	return ReachableInvocations{
   224  		Invocations: invocations,
   225  		Sources:     distinctSources,
   226  	}, nil
   227  }
   228  
   229  // resolveSources resolves the sources tested by the given invocation.
   230  func resolveSources(id invocations.ID, invToParent map[invocations.ID]invocations.ID, invToDetails map[invocations.ID]invocationDetails) *pb.Sources {
   231  	// If the invocation specifies that it inherits sources,
   232  	// walk the invocation graph back towards the root to
   233  	// resolve the sources.
   234  	invID := id
   235  	details := invToDetails[invID]
   236  	for details.InheritSources {
   237  		var ok bool
   238  		invID, ok = invToParent[invID]
   239  		if !ok {
   240  			// We have walked all the way back to the root,
   241  			// and even the root indicates it is inheriting sources.
   242  			// The actual sources cannot be resolved.
   243  			return nil
   244  		}
   245  		details = invToDetails[invID]
   246  	}
   247  	if details.Sources != nil {
   248  		// Sources found.
   249  		return details.Sources
   250  	}
   251  	// The invocation we inheriting sources from
   252  	// has no sources.
   253  	return nil
   254  }
   255  
// invocationDetails holds the per-invocation fields read by
// queryInvocationDetails.
//
// Realm is the value read from the Realm column. InheritSources is true only
// when the InheritSources column is non-NULL and true. Sources is the decoded
// Sources column, or nil when the stored value is empty.
type invocationDetails struct {
	Realm          string
	InheritSources bool
	Sources        *pb.Sources
}
   261  
   262  // queryInvocationDetails reads realm and source information
   263  // for the given list of invocations.
   264  func queryInvocationDetails(ctx context.Context, ids invocations.IDSet) (map[invocations.ID]invocationDetails, error) {
   265  	st := spanner.NewStatement(`
   266  		SELECT
   267  			i.InvocationId,
   268  			i.Realm,
   269  			i.InheritSources,
   270  			i.Sources,
   271  		FROM UNNEST(@invIDs) inv
   272  		JOIN Invocations i
   273  		ON i.InvocationId = inv`)
   274  	st.Params = spanutil.ToSpannerMap(map[string]any{
   275  		"invIDs": ids,
   276  	})
   277  	b := &spanutil.Buffer{}
   278  	results := make(map[invocations.ID]invocationDetails)
   279  	err := spanutil.Query(ctx, st, func(r *spanner.Row) error {
   280  		var invocationID invocations.ID
   281  		var realm spanner.NullString
   282  		var inheritSources spanner.NullBool
   283  		var sources spanutil.Compressed
   284  		if err := b.FromSpanner(r, &invocationID, &realm, &inheritSources, &sources); err != nil {
   285  			return err
   286  		}
   287  		var sourcesProto *pb.Sources
   288  		if len(sources) > 0 {
   289  			sourcesProto = &pb.Sources{}
   290  			err := proto.Unmarshal(sources, sourcesProto)
   291  			if err != nil {
   292  				return err
   293  			}
   294  		}
   295  		results[invocationID] = invocationDetails{
   296  			Realm:          realm.StringVal,
   297  			InheritSources: inheritSources.Valid && inheritSources.Bool,
   298  			Sources:        sourcesProto,
   299  		}
   300  		return nil
   301  	})
   302  	if err != nil {
   303  		return nil, err
   304  	}
   305  	return results, nil
   306  }
   307  
   308  // queryIncludedInvocations returns the set of invocations
   309  // included from invocations `ids`, as well as the invocation
   310  // they were included from.
   311  //
   312  // The returned map has a key for each included invocation.
   313  // The value corresponding to the key is the parent invocation.
   314  //
   315  // The same invocation can be included from multiple invocations,
   316  // i.e. there are multiple parents, then the parent in the map
   317  // is selected arbitrarily.
   318  func queryIncludedInvocations(ctx context.Context, ids invocations.IDSet) (map[invocations.ID]invocations.ID, error) {
   319  	st := spanner.NewStatement(`
   320  	SELECT
   321  		ii.InvocationID,
   322  		ii.IncludedInvocationID,
   323  	FROM
   324  		UNNEST(@invocations) inv
   325  		JOIN IncludedInvocations ii ON inv = ii.InvocationID`)
   326  	st.Params = spanutil.ToSpannerMap(spanutil.ToSpannerMap(map[string]any{
   327  		"invocations": ids,
   328  	}))
   329  	results := make(map[invocations.ID]invocations.ID)
   330  
   331  	b := &spanutil.Buffer{}
   332  	err := span.Query(ctx, st).Do(func(r *spanner.Row) error {
   333  		var invocationID invocations.ID
   334  		var includedInvocationID invocations.ID
   335  		if err := b.FromSpanner(r, &invocationID, &includedInvocationID); err != nil {
   336  			return err
   337  		}
   338  		if includingInvocationID, ok := results[includedInvocationID]; ok {
   339  			// If this invocation was included via multiple paths,
   340  			// keep just the one with the lexicographically first
   341  			// invocation ID.
   342  			if invocationID < includingInvocationID {
   343  				results[includedInvocationID] = invocationID
   344  			}
   345  		} else {
   346  			results[includedInvocationID] = invocationID
   347  		}
   348  		return nil
   349  	})
   350  	if err != nil {
   351  		return nil, err
   352  	}
   353  	return results, nil
   354  }
   355  
   356  func queryInvocations(ctx context.Context, query string, invocationsParam invocations.IDSet) (invocations.IDSet, error) {
   357  	invs := invocations.NewIDSet()
   358  	st := spanner.NewStatement(query)
   359  	st.Params = spanutil.ToSpannerMap(spanutil.ToSpannerMap(map[string]any{
   360  		"invocations": invocationsParam,
   361  	}))
   362  	b := &spanutil.Buffer{}
   363  	err := span.Query(ctx, st).Do(func(r *spanner.Row) error {
   364  		var invocationID invocations.ID
   365  		if err := b.FromSpanner(r, &invocationID); err != nil {
   366  			return err
   367  		}
   368  		invs.Add(invocationID)
   369  		return nil
   370  	})
   371  	return invs, err
   372  }
   373  
// reachCache is a cache of all invocations reachable from the given
// invocation, stored in Redis. The cached set is either correct or absent.
// The value of a reachCache is the ID of the invocation whose reachable set
// is cached.
//
// The cache must be written only after the set of reachable invocations
// becomes immutable, i.e. when the including invocation is finalized.
// This is important to be able to tolerate transient Redis failures
// and avoid a situation where we failed to update the currently stored set,
// ignored the failure and then, after Redis came back online, read the
// stale set.
type reachCache invocations.ID
   384  
   385  // key returns the Redis key.
   386  func (c reachCache) key() string {
   387  	return fmt.Sprintf("reach4:%s", c)
   388  }
   389  
   390  // Write writes the new value.
   391  // The value does not have to include c, this is implied.
   392  func (c reachCache) Write(ctx context.Context, value ReachableInvocations) (err error) {
   393  	ctx, ts := tracing.Start(ctx, "resultdb.reachCache.write",
   394  		attribute.String("id", string(c)),
   395  	)
   396  	defer func() { tracing.End(ts, err) }()
   397  
   398  	// Expect the set of reachable invocations to include the invocation
   399  	// for which the cache entry is.
   400  	if _, ok := value.Invocations[invocations.ID(c)]; !ok {
   401  		return errors.New("value is invalid, does not contain the root invocation itself")
   402  	}
   403  
   404  	conn, err := redisconn.Get(ctx)
   405  	if err != nil {
   406  		return err
   407  	}
   408  	defer conn.Close()
   409  
   410  	key := c.key()
   411  
   412  	marshaled, err := value.marshal()
   413  	if err != nil {
   414  		return errors.Annotate(err, "marshal").Err()
   415  	}
   416  	ts.SetAttributes(attribute.Int("size", len(marshaled)))
   417  
   418  	if err := conn.Send("SET", key, marshaled); err != nil {
   419  		return err
   420  	}
   421  	if err := conn.Send("EXPIRE", key, int(reachCacheExpiration.Seconds())); err != nil {
   422  		return err
   423  	}
   424  	_, err = conn.Do("")
   425  	return err
   426  }
   427  
   428  // TryWrite tries to write the new value. On failure, logs it.
   429  func (c reachCache) TryWrite(ctx context.Context, value ReachableInvocations) {
   430  	switch err := c.Write(ctx, value); {
   431  	case err == redisconn.ErrNotConfigured:
   432  
   433  	case err != nil:
   434  		logging.Warningf(ctx, "ReachCache: failed to write %s: %s", c, err)
   435  	}
   436  }
   437  
// ErrUnknownReach is returned by reachCache.Read if the cached value is
// absent. Callers compare against it by identity.
var ErrUnknownReach = fmt.Errorf("the reachable set is unknown")
   440  
   441  // Read reads the current value.
   442  // Returns ErrUnknownReach if the value is absent.
   443  //
   444  // If err is nil, ids includes c, even if it was not passed in Write().
   445  func (c reachCache) Read(ctx context.Context) (invs ReachableInvocations, err error) {
   446  	ctx, ts := tracing.Start(ctx, "resultdb.reachCache.read",
   447  		attribute.String("id", string(c)),
   448  	)
   449  	defer func() { tracing.End(ts, err) }()
   450  
   451  	conn, err := redisconn.Get(ctx)
   452  	if err != nil {
   453  		return ReachableInvocations{}, err
   454  	}
   455  	defer conn.Close()
   456  
   457  	b, err := redis.Bytes(conn.Do("GET", c.key()))
   458  	switch {
   459  	case err == redis.ErrNil:
   460  		return ReachableInvocations{}, ErrUnknownReach
   461  	case err != nil:
   462  		return ReachableInvocations{}, err
   463  	}
   464  	ts.SetAttributes(attribute.Int("size", len(b)))
   465  
   466  	return unmarshalReachableInvocations(b)
   467  }