go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/swarming/server/botsrv/botsrv.go (about)

     1  // Copyright 2022 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package botsrv knows how to authenticate calls from Swarming RBE bots.
    16  //
    17  // It checks PollState/BotSession tokens and bot credentials.
    18  package botsrv
    19  
    20  import (
    21  	"context"
    22  	"encoding/json"
    23  	"fmt"
    24  	"io"
    25  	"net/http"
    26  	"strings"
    27  
    28  	"google.golang.org/grpc/codes"
    29  	"google.golang.org/grpc/status"
    30  	"google.golang.org/protobuf/encoding/prototext"
    31  	"google.golang.org/protobuf/proto"
    32  
    33  	"go.chromium.org/luci/auth/identity"
    34  	"go.chromium.org/luci/common/clock"
    35  	"go.chromium.org/luci/common/errors"
    36  	"go.chromium.org/luci/common/logging"
    37  	"go.chromium.org/luci/common/retry/transient"
    38  	"go.chromium.org/luci/grpc/grpcutil"
    39  	"go.chromium.org/luci/server/auth"
    40  	"go.chromium.org/luci/server/auth/openid"
    41  	"go.chromium.org/luci/server/router"
    42  	"go.chromium.org/luci/tokenserver/auth/machine"
    43  
    44  	internalspb "go.chromium.org/luci/swarming/proto/internals"
    45  	"go.chromium.org/luci/swarming/server/hmactoken"
    46  )
    47  
// RequestBody should be implemented by a JSON-serializable struct representing
// format of some particular request.
//
// The handler installed by InstallHandler deserializes the JSON request body
// into such a struct and then calls these methods to get the tokens and the
// dimensions it needs to authenticate and validate the request.
type RequestBody interface {
	ExtractPollToken() []byte               // the poll token, if present
	ExtractSessionToken() []byte            // the session token, if present
	ExtractDimensions() map[string][]string // dimensions reported by the bot, if present
	ExtractDebugRequest() any               // serialized as JSON and logged on errors
}
    56  
// Request is extracted from an authenticated request from a bot.
//
// It is populated by the handler installed via InstallHandler after the
// tokens and the bot credentials were successfully checked:
//   - BotID is the first value of the "id" enforced dimension in PollState.
//   - SessionID is taken from the session token, if a non-expired one was sent.
//   - SessionTokenExpired is set if a session token was sent, but it has
//     already expired (in which case it is otherwise ignored).
//   - PollState comes from the poll token or, if there's no fresh poll token,
//     from the session token.
//   - Dimensions are the dimensions reported by the bot with dimensions
//     enforced by PollState applied on top.
type Request struct {
	BotID               string                 // validated bot ID
	SessionID           string                 // validated RBE bot session ID, if present
	SessionTokenExpired bool                   // true if the request has expired session token
	PollState           *internalspb.PollState // validated poll state
	Dimensions          map[string][]string    // validated dimensions
}
    65  
// Response is serialized as JSON and sent to the bot.
//
// A nil Response returned by a handler results in `{"ok": true}` being sent.
type Response any
    68  
// Handler handles an authenticated request from a bot.
//
// It takes a raw deserialized request body and all authenticated data extracted
// from it. B is the concrete struct the JSON request body is deserialized
// into.
//
// It returns a response that will be serialized and sent to the bot as JSON or
// a gRPC error code that will be converted into an HTTP error. A nil response
// with a nil error results in `{"ok": true}` being sent to the bot.
type Handler[B any] func(ctx context.Context, body *B, req *Request) (Response, error)
    77  
// Server knows how to authenticate bot requests and route them to handlers.
//
// Construct it with New and register handlers with InstallHandler.
type Server struct {
	router      *router.Router         // where InstallHandler registers POST routes
	middlewares router.MiddlewareChain // bot authentication middlewares applied to every route
	hmacSecret  *hmactoken.Secret      // used to validate poll and session tokens
}
    84  
    85  // New constructs new Server.
    86  func New(ctx context.Context, r *router.Router, projectID string, hmacSecret *hmactoken.Secret) *Server {
    87  	gaeAppDomain := fmt.Sprintf("%s.appspot.com", projectID)
    88  	return &Server{
    89  		router: r,
    90  		middlewares: router.MiddlewareChain{
    91  			// All supported bot authentication schemes. The first matching one wins.
    92  			auth.Authenticate(
    93  				// This checks "X-Luci-Gce-Vm-Token" header if present. The token
    94  				// audience should be `[https://][<prefix>-dot-]app.appspot.com`.
    95  				&openid.GoogleComputeAuthMethod{
    96  					Header: "X-Luci-Gce-Vm-Token",
    97  					AudienceCheck: func(_ context.Context, _ auth.RequestMetadata, aud string) (bool, error) {
    98  						aud = strings.TrimPrefix(aud, "https://")
    99  						return aud == gaeAppDomain || strings.HasSuffix(aud, "-dot-"+gaeAppDomain), nil
   100  					},
   101  				},
   102  				// This checks "X-Luci-Machine-Token" header if present.
   103  				&machine.MachineTokenAuthMethod{},
   104  				// This checks "Authorization" header if present.
   105  				&auth.GoogleOAuth2Method{
   106  					Scopes: []string{"https://www.googleapis.com/auth/userinfo.email"},
   107  				},
   108  			),
   109  		},
   110  		hmacSecret: hmacSecret,
   111  	}
   112  }
   113  
// RequestBodyConstraint is needed to make Go generics type checker happy.
//
// It states that the type is a pointer to B that also implements RequestBody.
// This allows InstallHandler to allocate the request body as `new(B)` and
// then call RequestBody methods on it by converting the pointer to the
// constrained type.
type RequestBodyConstraint[B any] interface {
	RequestBody
	*B
}
   119  
   120  // InstallHandler installs a bot request handler at the given route.
   121  func InstallHandler[B any, RB RequestBodyConstraint[B]](s *Server, route string, h Handler[B]) {
   122  	s.router.POST(route, s.middlewares, func(c *router.Context) {
   123  		ctx := c.Request.Context()
   124  		req := c.Request
   125  		wrt := c.Writer
   126  
   127  		// Deserialized request body.
   128  		var body *B
   129  
   130  		// Deserialized and validated tokens in the request.
   131  		var pollTokenState *internalspb.PollState
   132  		var sessionState *internalspb.BotSession
   133  
   134  		// This is either pollTokenState or the poll state inside sessionState,
   135  		// depending on which token is non-expired. Populated below.
   136  		var pollState *internalspb.PollState
   137  
   138  		// writeErr logs a gRPC error and writes it to the HTTP response.
   139  		writeErr := func(err error) {
   140  			// Log request details to help in debugging errors.
   141  			logging.Infof(ctx, "Bot IP: %s", auth.GetState(ctx).PeerIP())
   142  			logging.Infof(ctx, "Authenticated: %s", auth.GetState(ctx).PeerIdentity())
   143  			if pollState != nil {
   144  				logging.Infof(ctx, "Bot ID: %s", extractBotID(pollState))
   145  				logging.Infof(ctx, "Poll token ID: %s", pollState.Id)
   146  				logging.Infof(ctx, "RBE: %s", pollState.RbeInstance)
   147  				if pollState.DebugInfo != nil {
   148  					logging.Infof(ctx, "Poll token age: %s", clock.Now(ctx).Sub(pollState.DebugInfo.Created.AsTime()))
   149  				}
   150  			}
   151  			if sessionState != nil {
   152  				logging.Infof(ctx, "Session ID: %s", sessionState.RbeBotSessionId)
   153  			}
   154  			if body != nil {
   155  				blob, _ := json.MarshalIndent(RB(body).ExtractDebugRequest(), "", "  ")
   156  				logging.Infof(ctx, "Request body:\n%s", blob)
   157  			}
   158  
   159  			// Log the actual error.
   160  			err = grpcutil.GRPCifyAndLogErr(ctx, err)
   161  			statusCode := status.Code(err)
   162  			httpCode := grpcutil.CodeStatus(statusCode)
   163  			if statusCode == codes.Unavailable {
   164  				// UNAVAILABLE seems to happen a lot, but in bursts (probably when the
   165  				// RBE scheduler restarts). Log it at the warning severity to make other
   166  				// errors more noticeable.
   167  				logging.Warningf(ctx, "HTTP %d: %s", httpCode, err)
   168  			} else {
   169  				logging.Errorf(ctx, "HTTP %d: %s", httpCode, err)
   170  			}
   171  
   172  			http.Error(wrt, err.Error(), httpCode)
   173  		}
   174  
   175  		// Deserialize JSON request body.
   176  		if ct := req.Header.Get("Content-Type"); strings.ToLower(ct) != "application/json; charset=utf-8" {
   177  			writeErr(status.Errorf(codes.InvalidArgument, "bad content type %q", ct))
   178  			return
   179  		}
   180  		raw, err := io.ReadAll(req.Body)
   181  		if err != nil {
   182  			writeErr(status.Errorf(codes.Internal, "error reading request body: %s", err))
   183  			return
   184  		}
   185  		body = new(B)
   186  		if err := json.Unmarshal(raw, body); err != nil {
   187  			logging.Warningf(ctx, "Unrecognized request:\n%s", raw)
   188  			writeErr(status.Errorf(codes.InvalidArgument, "failed to deserialized the request: %s", err))
   189  			return
   190  		}
   191  
   192  		// To authenticate the bot we need either a non-expired poll token, a
   193  		// non-expired session token or both (in which case the poll token is
   194  		// preferred, since it should be more recently produced in this case). If we
   195  		// have a poll token, we validate it to directly get PollState. If we have
   196  		// a session token, we validate it and grab PollState from within it. This
   197  		// PollState is then used to check bot credentials.
   198  		//
   199  		// This scheme is necessary because poll tokens can be produced only by
   200  		// Python Swarming server when bot calls "/bot/poll" endpoint. When the bot
   201  		// is running a task, it isn't polling Python Swarming server and its poll
   202  		// token expires. For that reason when running a task (or making other
   203  		// post-task calls that happen before the next poll), we use the session
   204  		// token instead, which has the most recently validated PollState stored in
   205  		// it in a "frozen" state.
   206  		//
   207  		// When the bot is polling for tasks, it sends both poll token and session
   208  		// token to us, which allows us to put up-to-date PollState into the
   209  		// session token. This happens in UpdateBotSession handler.
   210  
   211  		// If have a poll token, validate and deserialize it.
   212  		if pollToken := RB(body).ExtractPollToken(); len(pollToken) != 0 {
   213  			pollTokenState = &internalspb.PollState{}
   214  			if err := s.hmacSecret.ValidateToken(pollToken, pollTokenState); err != nil {
   215  				writeErr(status.Errorf(codes.Unauthenticated, "failed to verify poll token: %s", err))
   216  				return
   217  			}
   218  			if exp := clock.Now(ctx).Sub(pollTokenState.Expiry.AsTime()); exp > 0 {
   219  				logging.Warningf(ctx, "Ignoring poll token (expired %s ago):\n%s", exp, prettyProto(pollTokenState))
   220  				pollTokenState = nil
   221  			}
   222  		}
   223  		// If have a session token, validate and deserialize it as well.
   224  		sessionTokenExpired := false
   225  		if sessionToken := RB(body).ExtractSessionToken(); len(sessionToken) != 0 {
   226  			sessionState = &internalspb.BotSession{}
   227  			if err := s.hmacSecret.ValidateToken(sessionToken, sessionState); err != nil {
   228  				writeErr(status.Errorf(codes.Unauthenticated, "failed to verify session token: %s", err))
   229  				return
   230  			}
   231  			if exp := clock.Now(ctx).Sub(sessionState.Expiry.AsTime()); exp > 0 {
   232  				logging.Warningf(ctx, "Ignoring session token (expired %s ago):\n%s", exp, prettyProto(sessionState))
   233  				sessionState = nil
   234  				sessionTokenExpired = true
   235  			}
   236  		}
   237  
   238  		// Need at least one valid and fresh token.
   239  		if pollTokenState == nil && sessionState == nil {
   240  			writeErr(status.Errorf(codes.Unauthenticated, "no valid poll or state token"))
   241  			return
   242  		}
   243  
   244  		// Prefer the state from the poll token. It is fresher. Fallback to the
   245  		// state stored in the session token if there's no poll token or it has
   246  		// expired.
   247  		pollState = pollTokenState
   248  		if pollState == nil {
   249  			pollState = sessionState.GetPollState()
   250  			if pollState == nil {
   251  				writeErr(status.Errorf(codes.Unauthenticated, "no poll state available"))
   252  				return
   253  			}
   254  		}
   255  
   256  		// Extract bot ID from the validated PollToken.
   257  		botID := extractBotID(pollState)
   258  		if botID == "" {
   259  			writeErr(status.Errorf(codes.InvalidArgument, "no bot ID"))
   260  			return
   261  		}
   262  		// Session ID must be present if there's a session token.
   263  		if sessionState != nil && sessionState.RbeBotSessionId == "" {
   264  			writeErr(status.Errorf(codes.InvalidArgument, "no session ID"))
   265  			return
   266  		}
   267  
   268  		// Verify bot credentials match what's recorded in the validated poll state.
   269  		if err := checkCredentials(ctx, pollState); err != nil {
   270  			if transient.Tag.In(err) {
   271  				writeErr(status.Errorf(codes.Internal, "transient error checking bot credentials: %s", err))
   272  			} else {
   273  				writeErr(status.Errorf(codes.Unauthenticated, "bad bot credentials: %s", err))
   274  			}
   275  			return
   276  		}
   277  
   278  		// Apply verified state stored in PollState on top of whatever was reported
   279  		// by the bot. Normally functioning bots should report the same values as
   280  		// stored in the token.
   281  		dims := RB(body).ExtractDimensions()
   282  		for _, dim := range pollState.EnforcedDimensions {
   283  			reported := dims[dim.Key]
   284  			if !strSliceEq(reported, dim.Values) {
   285  				logging.Errorf(ctx, "Dimension %q mismatch: reported %v, expecting %v",
   286  					dim.Key, reported, dim.Values,
   287  				)
   288  				dims[dim.Key] = dim.Values
   289  			}
   290  		}
   291  
   292  		// There must be `pool` dimension with at least one value (perhaps more).
   293  		if len(dims["pool"]) == 0 {
   294  			writeErr(status.Errorf(codes.InvalidArgument, "no pool dimension"))
   295  			return
   296  		}
   297  
   298  		// The request is valid, dispatch it to the handler.
   299  		resp, err := h(ctx, body, &Request{
   300  			BotID:               botID,
   301  			SessionID:           sessionState.GetRbeBotSessionId(),
   302  			SessionTokenExpired: sessionTokenExpired,
   303  			PollState:           pollState,
   304  			Dimensions:          dims,
   305  		})
   306  		if err != nil {
   307  			writeErr(err)
   308  			return
   309  		}
   310  
   311  		// Success! Write back the response.
   312  		wrt.Header().Set("Content-Type", "application/json; charset=utf-8")
   313  		var werr error
   314  		if resp == nil {
   315  			_, werr = wrt.Write([]byte("{\"ok\": true}\n"))
   316  		} else {
   317  			werr = json.NewEncoder(wrt).Encode(resp)
   318  		}
   319  		if werr != nil {
   320  			logging.Errorf(ctx, "Error writing the response: %s", werr)
   321  		}
   322  	})
   323  }
   324  
   325  // prettyProto formats a proto message for logs.
   326  func prettyProto(msg proto.Message) string {
   327  	blob, err := prototext.MarshalOptions{
   328  		Multiline: true,
   329  		Indent:    "  ",
   330  	}.Marshal(msg)
   331  	if err != nil {
   332  		return fmt.Sprintf("<error: %s>", err)
   333  	}
   334  	return string(blob)
   335  }
   336  
   337  // checkCredentials checks the bot credentials in the context match what is
   338  // required by the PollState.
   339  //
   340  // It ensures the Go portion of the Swarming server authenticates the bot in
   341  // the exact same way the Python portion did (since the Python portion produced
   342  // the PollState after it authenticated the bot).
   343  func checkCredentials(ctx context.Context, pollState *internalspb.PollState) error {
   344  	switch m := pollState.AuthMethod.(type) {
   345  	case *internalspb.PollState_GceAuth:
   346  		gceInfo := openid.GetGoogleComputeTokenInfo(ctx)
   347  		if gceInfo == nil {
   348  			return errors.Reason("expecting GCE VM token auth").Err()
   349  		}
   350  		if gceInfo.Project != m.GceAuth.GceProject || gceInfo.Instance != m.GceAuth.GceInstance {
   351  			logging.Errorf(ctx, "Bad GCE VM auth: want %s@%s, got %s@%s",
   352  				m.GceAuth.GceInstance, m.GceAuth.GceProject,
   353  				gceInfo.Instance, gceInfo.Project,
   354  			)
   355  			return errors.Reason("wrong GCE VM token: %s@%s", gceInfo.Instance, gceInfo.Project).Err()
   356  		}
   357  
   358  	case *internalspb.PollState_ServiceAccountAuth_:
   359  		peerID := auth.GetState(ctx).PeerIdentity()
   360  		if peerID.Kind() != identity.User {
   361  			return errors.Reason("expecting service account credentials").Err()
   362  		}
   363  		if peerID.Email() != m.ServiceAccountAuth.ServiceAccount {
   364  			logging.Errorf(ctx, "Bad service account auth: want %s, got %s",
   365  				m.ServiceAccountAuth.ServiceAccount,
   366  				peerID.Email(),
   367  			)
   368  			return errors.Reason("wrong service account: %s", peerID.Email()).Err()
   369  		}
   370  
   371  	case *internalspb.PollState_LuciMachineTokenAuth:
   372  		tokInfo := machine.GetMachineTokenInfo(ctx)
   373  		if tokInfo == nil {
   374  			return errors.Reason("expecting LUCI machine token auth").Err()
   375  		}
   376  		if tokInfo.FQDN != m.LuciMachineTokenAuth.MachineFqdn {
   377  			logging.Errorf(ctx, "Bad LUCI machine token FQDN: want %s, got %s",
   378  				m.LuciMachineTokenAuth.MachineFqdn,
   379  				tokInfo.FQDN,
   380  			)
   381  			return errors.Reason("wrong FQDN in the LUCI machine token: %s", tokInfo.FQDN).Err()
   382  		}
   383  
   384  	case *internalspb.PollState_IpAllowlistAuth:
   385  		// The actual check is below. Here we just verify the PollState token is
   386  		// consistent.
   387  		if pollState.IpAllowlist == "" {
   388  			return errors.Reason("bad poll token: using IP allowlist auth without an IP allowlist").Err()
   389  		}
   390  
   391  	default:
   392  		return errors.Reason("unrecognized auth method in the poll token: %v", pollState.AuthMethod).Err()
   393  	}
   394  
   395  	// Verify the bot is in the required IP allowlist (if any).
   396  	if pollState.IpAllowlist != "" {
   397  		switch yes, err := auth.IsAllowedIP(ctx, pollState.IpAllowlist); {
   398  		case err != nil:
   399  			return errors.Annotate(err, "IP allowlist check failed").Tag(transient.Tag).Err()
   400  		case !yes:
   401  			return errors.Reason("bot IP %s is not in the allowlist", auth.GetState(ctx).PeerIP()).Err()
   402  		}
   403  	}
   404  
   405  	return nil
   406  }
   407  
   408  // extractBotID extracts the bot ID from PollState.
   409  //
   410  // Returns "" if it is not present.
   411  func extractBotID(s *internalspb.PollState) string {
   412  	for _, dim := range s.EnforcedDimensions {
   413  		if dim.Key == "id" {
   414  			if len(dim.Values) > 0 {
   415  				return dim.Values[0]
   416  			}
   417  			return ""
   418  		}
   419  	}
   420  	return ""
   421  }
   422  
   423  // strSliceEq is true if two string slices are equal.
   424  func strSliceEq(a, b []string) bool {
   425  	if len(a) != len(b) {
   426  		return false
   427  	}
   428  	for i := range a {
   429  		if a[i] != b[i] {
   430  			return false
   431  		}
   432  	}
   433  	return true
   434  }