go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/swarming/server/botsrv/botsrv.go (about) 1 // Copyright 2022 The LUCI Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // Package botsrv knows how to authenticate calls from Swarming RBE bots. 16 // 17 // It checks PollState/BotSession tokens and bot credentials. 18 package botsrv 19 20 import ( 21 "context" 22 "encoding/json" 23 "fmt" 24 "io" 25 "net/http" 26 "strings" 27 28 "google.golang.org/grpc/codes" 29 "google.golang.org/grpc/status" 30 "google.golang.org/protobuf/encoding/prototext" 31 "google.golang.org/protobuf/proto" 32 33 "go.chromium.org/luci/auth/identity" 34 "go.chromium.org/luci/common/clock" 35 "go.chromium.org/luci/common/errors" 36 "go.chromium.org/luci/common/logging" 37 "go.chromium.org/luci/common/retry/transient" 38 "go.chromium.org/luci/grpc/grpcutil" 39 "go.chromium.org/luci/server/auth" 40 "go.chromium.org/luci/server/auth/openid" 41 "go.chromium.org/luci/server/router" 42 "go.chromium.org/luci/tokenserver/auth/machine" 43 44 internalspb "go.chromium.org/luci/swarming/proto/internals" 45 "go.chromium.org/luci/swarming/server/hmactoken" 46 ) 47 48 // RequestBody should be implemented by a JSON-serializable struct representing 49 // format of some particular request. 
50 type RequestBody interface { 51 ExtractPollToken() []byte // the poll token, if present 52 ExtractSessionToken() []byte // the session token, if present 53 ExtractDimensions() map[string][]string // dimensions reported by the bot, if present 54 ExtractDebugRequest() any // serialized as JSON and logged on errors 55 } 56 57 // Request is extracted from an authenticated request from a bot. 58 type Request struct { 59 BotID string // validated bot ID 60 SessionID string // validated RBE bot session ID, if present 61 SessionTokenExpired bool // true if the request has expired session token 62 PollState *internalspb.PollState // validated poll state 63 Dimensions map[string][]string // validated dimensions 64 } 65 66 // Response is serialized as JSON and sent to the bot. 67 type Response any 68 69 // Handler handles an authenticated request from a bot. 70 // 71 // It takes a raw deserialized request body and all authenticated data extracted 72 // from it. 73 // 74 // It returns a response that will be serialized and sent to the bot as JSON or 75 // a gRPC error code that will be converted into an HTTP error. 76 type Handler[B any] func(ctx context.Context, body *B, req *Request) (Response, error) 77 78 // Server knows how to authenticate bot requests and route them to handlers. 79 type Server struct { 80 router *router.Router 81 middlewares router.MiddlewareChain 82 hmacSecret *hmactoken.Secret 83 } 84 85 // New constructs new Server. 86 func New(ctx context.Context, r *router.Router, projectID string, hmacSecret *hmactoken.Secret) *Server { 87 gaeAppDomain := fmt.Sprintf("%s.appspot.com", projectID) 88 return &Server{ 89 router: r, 90 middlewares: router.MiddlewareChain{ 91 // All supported bot authentication schemes. The first matching one wins. 92 auth.Authenticate( 93 // This checks "X-Luci-Gce-Vm-Token" header if present. The token 94 // audience should be `[https://][<prefix>-dot-]app.appspot.com`. 
95 &openid.GoogleComputeAuthMethod{ 96 Header: "X-Luci-Gce-Vm-Token", 97 AudienceCheck: func(_ context.Context, _ auth.RequestMetadata, aud string) (bool, error) { 98 aud = strings.TrimPrefix(aud, "https://") 99 return aud == gaeAppDomain || strings.HasSuffix(aud, "-dot-"+gaeAppDomain), nil 100 }, 101 }, 102 // This checks "X-Luci-Machine-Token" header if present. 103 &machine.MachineTokenAuthMethod{}, 104 // This checks "Authorization" header if present. 105 &auth.GoogleOAuth2Method{ 106 Scopes: []string{"https://www.googleapis.com/auth/userinfo.email"}, 107 }, 108 ), 109 }, 110 hmacSecret: hmacSecret, 111 } 112 } 113 114 // RequestBodyConstraint is needed to make Go generics type checker happy. 115 type RequestBodyConstraint[B any] interface { 116 RequestBody 117 *B 118 } 119 120 // InstallHandler installs a bot request handler at the given route. 121 func InstallHandler[B any, RB RequestBodyConstraint[B]](s *Server, route string, h Handler[B]) { 122 s.router.POST(route, s.middlewares, func(c *router.Context) { 123 ctx := c.Request.Context() 124 req := c.Request 125 wrt := c.Writer 126 127 // Deserialized request body. 128 var body *B 129 130 // Deserialized and validated tokens in the request. 131 var pollTokenState *internalspb.PollState 132 var sessionState *internalspb.BotSession 133 134 // This is either pollTokenState or the poll state inside sessionState, 135 // depending on which token is non-expired. Populated below. 136 var pollState *internalspb.PollState 137 138 // writeErr logs a gRPC error and writes it to the HTTP response. 139 writeErr := func(err error) { 140 // Log request details to help in debugging errors. 
141 logging.Infof(ctx, "Bot IP: %s", auth.GetState(ctx).PeerIP()) 142 logging.Infof(ctx, "Authenticated: %s", auth.GetState(ctx).PeerIdentity()) 143 if pollState != nil { 144 logging.Infof(ctx, "Bot ID: %s", extractBotID(pollState)) 145 logging.Infof(ctx, "Poll token ID: %s", pollState.Id) 146 logging.Infof(ctx, "RBE: %s", pollState.RbeInstance) 147 if pollState.DebugInfo != nil { 148 logging.Infof(ctx, "Poll token age: %s", clock.Now(ctx).Sub(pollState.DebugInfo.Created.AsTime())) 149 } 150 } 151 if sessionState != nil { 152 logging.Infof(ctx, "Session ID: %s", sessionState.RbeBotSessionId) 153 } 154 if body != nil { 155 blob, _ := json.MarshalIndent(RB(body).ExtractDebugRequest(), "", " ") 156 logging.Infof(ctx, "Request body:\n%s", blob) 157 } 158 159 // Log the actual error. 160 err = grpcutil.GRPCifyAndLogErr(ctx, err) 161 statusCode := status.Code(err) 162 httpCode := grpcutil.CodeStatus(statusCode) 163 if statusCode == codes.Unavailable { 164 // UNAVAILABLE seems to happen a lot, but in bursts (probably when the 165 // RBE scheduler restarts). Log it at the warning severity to make other 166 // errors more noticeable. 167 logging.Warningf(ctx, "HTTP %d: %s", httpCode, err) 168 } else { 169 logging.Errorf(ctx, "HTTP %d: %s", httpCode, err) 170 } 171 172 http.Error(wrt, err.Error(), httpCode) 173 } 174 175 // Deserialize JSON request body. 
176 if ct := req.Header.Get("Content-Type"); strings.ToLower(ct) != "application/json; charset=utf-8" { 177 writeErr(status.Errorf(codes.InvalidArgument, "bad content type %q", ct)) 178 return 179 } 180 raw, err := io.ReadAll(req.Body) 181 if err != nil { 182 writeErr(status.Errorf(codes.Internal, "error reading request body: %s", err)) 183 return 184 } 185 body = new(B) 186 if err := json.Unmarshal(raw, body); err != nil { 187 logging.Warningf(ctx, "Unrecognized request:\n%s", raw) 188 writeErr(status.Errorf(codes.InvalidArgument, "failed to deserialized the request: %s", err)) 189 return 190 } 191 192 // To authenticate the bot we need either a non-expired poll token, a 193 // non-expired session token or both (in which case the poll token is 194 // preferred, since it should be more recently produced in this case). If we 195 // have a poll token, we validate it to directly get PollState. If we have 196 // a session token, we validate it and grab PollState from within it. This 197 // PollState is then used to check bot credentials. 198 // 199 // This scheme is necessary because poll tokens can be produced only by 200 // Python Swarming server when bot calls "/bot/poll" endpoint. When the bot 201 // is running a task, it isn't polling Python Swarming server and its poll 202 // token expires. For that reason when running a task (or making other 203 // post-task calls that happen before the next poll), we use the session 204 // token instead, which has the most recently validated PollState stored in 205 // it in a "frozen" state. 206 // 207 // When the bot is polling for tasks, it sends both poll token and session 208 // token to us, which allows us to put up-to-date PollState into the 209 // session token. This happens in UpdateBotSession handler. 210 211 // If have a poll token, validate and deserialize it. 
212 if pollToken := RB(body).ExtractPollToken(); len(pollToken) != 0 { 213 pollTokenState = &internalspb.PollState{} 214 if err := s.hmacSecret.ValidateToken(pollToken, pollTokenState); err != nil { 215 writeErr(status.Errorf(codes.Unauthenticated, "failed to verify poll token: %s", err)) 216 return 217 } 218 if exp := clock.Now(ctx).Sub(pollTokenState.Expiry.AsTime()); exp > 0 { 219 logging.Warningf(ctx, "Ignoring poll token (expired %s ago):\n%s", exp, prettyProto(pollTokenState)) 220 pollTokenState = nil 221 } 222 } 223 // If have a session token, validate and deserialize it as well. 224 sessionTokenExpired := false 225 if sessionToken := RB(body).ExtractSessionToken(); len(sessionToken) != 0 { 226 sessionState = &internalspb.BotSession{} 227 if err := s.hmacSecret.ValidateToken(sessionToken, sessionState); err != nil { 228 writeErr(status.Errorf(codes.Unauthenticated, "failed to verify session token: %s", err)) 229 return 230 } 231 if exp := clock.Now(ctx).Sub(sessionState.Expiry.AsTime()); exp > 0 { 232 logging.Warningf(ctx, "Ignoring session token (expired %s ago):\n%s", exp, prettyProto(sessionState)) 233 sessionState = nil 234 sessionTokenExpired = true 235 } 236 } 237 238 // Need at least one valid and fresh token. 239 if pollTokenState == nil && sessionState == nil { 240 writeErr(status.Errorf(codes.Unauthenticated, "no valid poll or state token")) 241 return 242 } 243 244 // Prefer the state from the poll token. It is fresher. Fallback to the 245 // state stored in the session token if there's no poll token or it has 246 // expired. 247 pollState = pollTokenState 248 if pollState == nil { 249 pollState = sessionState.GetPollState() 250 if pollState == nil { 251 writeErr(status.Errorf(codes.Unauthenticated, "no poll state available")) 252 return 253 } 254 } 255 256 // Extract bot ID from the validated PollToken. 
257 botID := extractBotID(pollState) 258 if botID == "" { 259 writeErr(status.Errorf(codes.InvalidArgument, "no bot ID")) 260 return 261 } 262 // Session ID must be present if there's a session token. 263 if sessionState != nil && sessionState.RbeBotSessionId == "" { 264 writeErr(status.Errorf(codes.InvalidArgument, "no session ID")) 265 return 266 } 267 268 // Verify bot credentials match what's recorded in the validated poll state. 269 if err := checkCredentials(ctx, pollState); err != nil { 270 if transient.Tag.In(err) { 271 writeErr(status.Errorf(codes.Internal, "transient error checking bot credentials: %s", err)) 272 } else { 273 writeErr(status.Errorf(codes.Unauthenticated, "bad bot credentials: %s", err)) 274 } 275 return 276 } 277 278 // Apply verified state stored in PollState on top of whatever was reported 279 // by the bot. Normally functioning bots should report the same values as 280 // stored in the token. 281 dims := RB(body).ExtractDimensions() 282 for _, dim := range pollState.EnforcedDimensions { 283 reported := dims[dim.Key] 284 if !strSliceEq(reported, dim.Values) { 285 logging.Errorf(ctx, "Dimension %q mismatch: reported %v, expecting %v", 286 dim.Key, reported, dim.Values, 287 ) 288 dims[dim.Key] = dim.Values 289 } 290 } 291 292 // There must be `pool` dimension with at least one value (perhaps more). 293 if len(dims["pool"]) == 0 { 294 writeErr(status.Errorf(codes.InvalidArgument, "no pool dimension")) 295 return 296 } 297 298 // The request is valid, dispatch it to the handler. 299 resp, err := h(ctx, body, &Request{ 300 BotID: botID, 301 SessionID: sessionState.GetRbeBotSessionId(), 302 SessionTokenExpired: sessionTokenExpired, 303 PollState: pollState, 304 Dimensions: dims, 305 }) 306 if err != nil { 307 writeErr(err) 308 return 309 } 310 311 // Success! Write back the response. 
312 wrt.Header().Set("Content-Type", "application/json; charset=utf-8") 313 var werr error 314 if resp == nil { 315 _, werr = wrt.Write([]byte("{\"ok\": true}\n")) 316 } else { 317 werr = json.NewEncoder(wrt).Encode(resp) 318 } 319 if werr != nil { 320 logging.Errorf(ctx, "Error writing the response: %s", werr) 321 } 322 }) 323 } 324 325 // prettyProto formats a proto message for logs. 326 func prettyProto(msg proto.Message) string { 327 blob, err := prototext.MarshalOptions{ 328 Multiline: true, 329 Indent: " ", 330 }.Marshal(msg) 331 if err != nil { 332 return fmt.Sprintf("<error: %s>", err) 333 } 334 return string(blob) 335 } 336 337 // checkCredentials checks the bot credentials in the context match what is 338 // required by the PollState. 339 // 340 // It ensures the Go portion of the Swarming server authenticates the bot in 341 // the exact same way the Python portion did (since the Python portion produced 342 // the PollState after it authenticated the bot). 343 func checkCredentials(ctx context.Context, pollState *internalspb.PollState) error { 344 switch m := pollState.AuthMethod.(type) { 345 case *internalspb.PollState_GceAuth: 346 gceInfo := openid.GetGoogleComputeTokenInfo(ctx) 347 if gceInfo == nil { 348 return errors.Reason("expecting GCE VM token auth").Err() 349 } 350 if gceInfo.Project != m.GceAuth.GceProject || gceInfo.Instance != m.GceAuth.GceInstance { 351 logging.Errorf(ctx, "Bad GCE VM auth: want %s@%s, got %s@%s", 352 m.GceAuth.GceInstance, m.GceAuth.GceProject, 353 gceInfo.Instance, gceInfo.Project, 354 ) 355 return errors.Reason("wrong GCE VM token: %s@%s", gceInfo.Instance, gceInfo.Project).Err() 356 } 357 358 case *internalspb.PollState_ServiceAccountAuth_: 359 peerID := auth.GetState(ctx).PeerIdentity() 360 if peerID.Kind() != identity.User { 361 return errors.Reason("expecting service account credentials").Err() 362 } 363 if peerID.Email() != m.ServiceAccountAuth.ServiceAccount { 364 logging.Errorf(ctx, "Bad service account auth: want 
%s, got %s", 365 m.ServiceAccountAuth.ServiceAccount, 366 peerID.Email(), 367 ) 368 return errors.Reason("wrong service account: %s", peerID.Email()).Err() 369 } 370 371 case *internalspb.PollState_LuciMachineTokenAuth: 372 tokInfo := machine.GetMachineTokenInfo(ctx) 373 if tokInfo == nil { 374 return errors.Reason("expecting LUCI machine token auth").Err() 375 } 376 if tokInfo.FQDN != m.LuciMachineTokenAuth.MachineFqdn { 377 logging.Errorf(ctx, "Bad LUCI machine token FQDN: want %s, got %s", 378 m.LuciMachineTokenAuth.MachineFqdn, 379 tokInfo.FQDN, 380 ) 381 return errors.Reason("wrong FQDN in the LUCI machine token: %s", tokInfo.FQDN).Err() 382 } 383 384 case *internalspb.PollState_IpAllowlistAuth: 385 // The actual check is below. Here we just verify the PollState token is 386 // consistent. 387 if pollState.IpAllowlist == "" { 388 return errors.Reason("bad poll token: using IP allowlist auth without an IP allowlist").Err() 389 } 390 391 default: 392 return errors.Reason("unrecognized auth method in the poll token: %v", pollState.AuthMethod).Err() 393 } 394 395 // Verify the bot is in the required IP allowlist (if any). 396 if pollState.IpAllowlist != "" { 397 switch yes, err := auth.IsAllowedIP(ctx, pollState.IpAllowlist); { 398 case err != nil: 399 return errors.Annotate(err, "IP allowlist check failed").Tag(transient.Tag).Err() 400 case !yes: 401 return errors.Reason("bot IP %s is not in the allowlist", auth.GetState(ctx).PeerIP()).Err() 402 } 403 } 404 405 return nil 406 } 407 408 // extractBotID extracts the bot ID from PollState. 409 // 410 // Returns "" if it is not present. 411 func extractBotID(s *internalspb.PollState) string { 412 for _, dim := range s.EnforcedDimensions { 413 if dim.Key == "id" { 414 if len(dim.Values) > 0 { 415 return dim.Values[0] 416 } 417 return "" 418 } 419 } 420 return "" 421 } 422 423 // strSliceEq is true if two string slices are equal. 
func strSliceEq(a, b []string) bool {
	// Only same-length slices can match; nil and empty slices are treated as
	// equal since both have length 0.
	if len(a) == len(b) {
		// Walk a and compare every element against the same position in b,
		// stopping at the first difference.
		for idx, s := range a {
			if b[idx] != s {
				return false
			}
		}
		return true
	}
	return false
}