go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/resultdb/internal/services/recorder/batch_create_invocations.go (about)

     1  // Copyright 2020 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package recorder
    16  
    17  import (
    18  	"context"
    19  	"time"
    20  
    21  	"cloud.google.com/go/spanner"
    22  	"google.golang.org/grpc/codes"
    23  
    24  	"go.chromium.org/luci/common/clock"
    25  	"go.chromium.org/luci/common/errors"
    26  	"go.chromium.org/luci/grpc/appstatus"
    27  	"go.chromium.org/luci/resultdb/internal/invocations"
    28  	"go.chromium.org/luci/resultdb/internal/permissions"
    29  	"go.chromium.org/luci/resultdb/internal/spanutil"
    30  	"go.chromium.org/luci/resultdb/internal/tasks"
    31  	"go.chromium.org/luci/resultdb/pbutil"
    32  	pb "go.chromium.org/luci/resultdb/proto/v1"
    33  	"go.chromium.org/luci/server/auth"
    34  	"go.chromium.org/luci/server/span"
    35  )
    36  
    37  // validateBatchCreateInvocationsRequest checks that the individual requests
    38  // are valid, that they match the batch request requestID and that their names
    39  // are not repeated.
    40  // It also returns an IDSet containing the ids of all the invocations to be
    41  // included in the new invocations.
    42  func validateBatchCreateInvocationsRequest(
    43  	now time.Time, reqs []*pb.CreateInvocationRequest, requestID string) (newInvs, includedInvs invocations.IDSet, err error) {
    44  	if err := pbutil.ValidateRequestID(requestID); err != nil {
    45  		return nil, nil, errors.Annotate(err, "request_id").Err()
    46  	}
    47  
    48  	if err := pbutil.ValidateBatchRequestCount(len(reqs)); err != nil {
    49  		return nil, nil, err
    50  	}
    51  
    52  	newInvs = make(invocations.IDSet, len(reqs))
    53  	allIncludedIDs := make(invocations.IDSet)
    54  	for i, req := range reqs {
    55  		if err := validateCreateInvocationRequest(req, now, allIncludedIDs); err != nil {
    56  			return nil, nil, errors.Annotate(err, "requests[%d]", i).Err()
    57  		}
    58  
    59  		// If there's multiple `CreateInvocationRequest`s their request id
    60  		// must either be empty or match the one in the batch request.
    61  		if req.RequestId != "" && req.RequestId != requestID {
    62  			return nil, nil, errors.Reason("requests[%d].request_id: %q does not match request_id %q", i, requestID, req.RequestId).Err()
    63  		}
    64  
    65  		invID := invocations.ID(req.InvocationId)
    66  		if newInvs.Has(invID) {
    67  			return nil, nil, errors.Reason("requests[%d].invocation_id: duplicated invocation id %q", i, req.InvocationId).Err()
    68  		}
    69  		newInvs.Add(invID)
    70  	}
    71  
    72  	return newInvs, allIncludedIDs, nil
    73  }
    74  
    75  // BatchCreateInvocations implements pb.RecorderServer.
    76  func (s *recorderServer) BatchCreateInvocations(ctx context.Context, in *pb.BatchCreateInvocationsRequest) (*pb.BatchCreateInvocationsResponse, error) {
    77  	now := clock.Now(ctx).UTC()
    78  	for i, r := range in.Requests {
    79  		if err := verifyCreateInvocationPermissions(ctx, r); err != nil {
    80  			return nil, errors.Annotate(err, "requests[%d]", i).Err()
    81  		}
    82  
    83  	}
    84  
    85  	idSet, includedInvs, err := validateBatchCreateInvocationsRequest(now, in.Requests, in.RequestId)
    86  	if err != nil {
    87  		return nil, appstatus.BadRequest(err)
    88  	}
    89  
    90  	includedInvs.RemoveAll(idSet)
    91  	if err := permissions.VerifyInvocations(span.Single(ctx), includedInvs, permIncludeInvocation); err != nil {
    92  		return nil, err
    93  	}
    94  
    95  	invs, tokens, err := s.createInvocations(ctx, in.Requests, in.RequestId, now, idSet)
    96  	if err != nil {
    97  		return nil, err
    98  	}
    99  	return &pb.BatchCreateInvocationsResponse{Invocations: invs, UpdateTokens: tokens}, nil
   100  }
   101  
   102  // createInvocations is a shared implementation for CreateInvocation and BatchCreateInvocations RPCs.
   103  func (s *recorderServer) createInvocations(ctx context.Context, reqs []*pb.CreateInvocationRequest, requestID string, now time.Time, idSet invocations.IDSet) ([]*pb.Invocation, []string, error) {
   104  	createdBy := string(auth.CurrentIdentity(ctx))
   105  	ms := s.createInvocationsRequestsToMutations(ctx, now, reqs, requestID, createdBy)
   106  
   107  	var err error
   108  	deduped := false
   109  	_, err = span.ReadWriteTransaction(ctx, func(ctx context.Context) error {
   110  		deduped, err = deduplicateCreateInvocations(ctx, idSet, requestID, createdBy)
   111  		if err != nil {
   112  			return err
   113  		}
   114  		if !deduped {
   115  			span.BufferWrite(ctx, ms...)
   116  			// Enqueue any finalization tasks in the same transaction.
   117  			for _, req := range reqs {
   118  				if req.Invocation.State == pb.Invocation_FINALIZING {
   119  					tasks.StartInvocationFinalization(ctx, invocations.ID(req.InvocationId), false)
   120  				}
   121  			}
   122  		}
   123  
   124  		return nil
   125  	})
   126  	if err != nil {
   127  		return nil, nil, err
   128  	}
   129  	if !deduped {
   130  		for _, r := range reqs {
   131  			spanutil.IncRowCount(ctx, 1, spanutil.Invocations, spanutil.Inserted, r.Invocation.GetRealm())
   132  		}
   133  	}
   134  
   135  	return getCreatedInvocationsAndUpdateTokens(ctx, idSet, reqs)
   136  }
   137  
   138  // createInvocationsRequestsToMutations computes a database mutation for
   139  // inserting a row for each invocation creation requested.
   140  func (s *recorderServer) createInvocationsRequestsToMutations(ctx context.Context, now time.Time, reqs []*pb.CreateInvocationRequest, requestID, createdBy string) []*spanner.Mutation {
   141  
   142  	ms := make([]*spanner.Mutation, 0, len(reqs))
   143  	// Compute mutations
   144  	for _, req := range reqs {
   145  		newInvState := req.Invocation.GetState()
   146  		if newInvState == pb.Invocation_STATE_UNSPECIFIED {
   147  			newInvState = pb.Invocation_ACTIVE
   148  		}
   149  		if newInvState != pb.Invocation_ACTIVE && newInvState != pb.Invocation_FINALIZING {
   150  			// validateCreateInvocationRequest should have rejected any other states.
   151  			panic("do not create invocations in states other than active or finalizing")
   152  		}
   153  
   154  		// Prepare the invocation we will save to spanner.
   155  		inv := &pb.Invocation{
   156  			Name:             invocations.ID(req.InvocationId).Name(),
   157  			State:            newInvState,
   158  			Deadline:         req.Invocation.GetDeadline(),
   159  			Tags:             req.Invocation.GetTags(),
   160  			BigqueryExports:  req.Invocation.GetBigqueryExports(),
   161  			CreatedBy:        createdBy,
   162  			ProducerResource: req.Invocation.GetProducerResource(),
   163  			Realm:            req.Invocation.GetRealm(),
   164  			Properties:       req.Invocation.GetProperties(),
   165  			SourceSpec:       req.Invocation.GetSourceSpec(),
   166  			BaselineId:       req.Invocation.GetBaselineId(),
   167  		}
   168  
   169  		// Ensure the invocation has a deadline.
   170  		if inv.Deadline == nil {
   171  			inv.Deadline = pbutil.MustTimestampProto(now.Add(defaultInvocationDeadlineDuration))
   172  		}
   173  
   174  		pbutil.NormalizeInvocation(inv)
   175  		// Create a mutation to create the invocation.
   176  		ms = append(ms, spanutil.InsertMap("Invocations", s.rowOfInvocation(ctx, inv, requestID)))
   177  
   178  		// Add any inclusions.
   179  		for _, incName := range req.Invocation.IncludedInvocations {
   180  			ms = append(ms, spanutil.InsertMap("IncludedInvocations", map[string]any{
   181  				"InvocationId":         invocations.ID(req.InvocationId),
   182  				"IncludedInvocationId": invocations.MustParseName(incName),
   183  			}))
   184  		}
   185  	}
   186  	return ms
   187  }
   188  
   189  // getCreatedInvocationsAndUpdateTokens reads the full details of the
   190  // invocations just created in a separate read-only transaction, and
   191  // generates an update token for each.
   192  func getCreatedInvocationsAndUpdateTokens(ctx context.Context, idSet invocations.IDSet, reqs []*pb.CreateInvocationRequest) ([]*pb.Invocation, []string, error) {
   193  	ctx, cancel := span.ReadOnlyTransaction(ctx)
   194  	defer cancel()
   195  
   196  	invMap, err := invocations.ReadBatch(ctx, idSet)
   197  	if err != nil {
   198  		return nil, nil, err
   199  	}
   200  
   201  	// Arrange them in same order as the incoming requests.
   202  	// Ordering is important to match the tokens.
   203  	invs := make([]*pb.Invocation, len(reqs))
   204  	for i, req := range reqs {
   205  		invs[i] = invMap[invocations.ID(req.InvocationId)]
   206  	}
   207  
   208  	tokens, err := generateTokens(ctx, invs)
   209  	if err != nil {
   210  		return nil, nil, err
   211  	}
   212  	return invs, tokens, nil
   213  }
   214  
   215  // deduplicateCreateInvocations checks if the invocations have already been
   216  // created with the given requestID and current requester.
   217  // Returns a true if they have.
   218  func deduplicateCreateInvocations(ctx context.Context, idSet invocations.IDSet, requestID, createdBy string) (bool, error) {
   219  	invCount := 0
   220  	columns := []string{"InvocationId", "CreateRequestId", "CreatedBy"}
   221  	err := span.Read(ctx, "Invocations", idSet.Keys(), columns).Do(func(r *spanner.Row) error {
   222  		var invID invocations.ID
   223  		var rowRequestID spanner.NullString
   224  		var rowCreatedBy spanner.NullString
   225  		switch err := spanutil.FromSpanner(r, &invID, &rowRequestID, &rowCreatedBy); {
   226  		case err != nil:
   227  			return err
   228  		case !rowRequestID.Valid || rowRequestID.StringVal != requestID:
   229  			return invocationAlreadyExists(invID)
   230  		case rowCreatedBy.StringVal != createdBy:
   231  			return invocationAlreadyExists(invID)
   232  		default:
   233  			invCount++
   234  			return nil
   235  		}
   236  	})
   237  	switch {
   238  	case err != nil:
   239  		return false, err
   240  	case invCount == len(idSet):
   241  		// All invocations were previously created with this request id.
   242  		return true, nil
   243  	case invCount == 0:
   244  		// None of the invocations exist already.
   245  		return false, nil
   246  	default:
   247  		// Could happen if someone sent two different but overlapping batch create
   248  		// requests, but reused the request_id.
   249  		return false, appstatus.Errorf(codes.AlreadyExists, "some, but not all of the invocations already created with this request id")
   250  	}
   251  }
   252  
   253  // generateTokens generates an update token for each invocation.
   254  func generateTokens(ctx context.Context, invs []*pb.Invocation) ([]string, error) {
   255  	ret := make([]string, len(invs))
   256  	for i, inv := range invs {
   257  		updateToken, err := generateInvocationToken(ctx, invocations.MustParseName(inv.Name))
   258  		if err != nil {
   259  			return nil, err
   260  		}
   261  		ret[i] = updateToken
   262  	}
   263  	return ret, nil
   264  }