go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/auth/integration/localauth/server.go (about) 1 // Copyright 2017 The LUCI Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package localauth 16 17 import ( 18 "context" 19 "crypto/subtle" 20 "encoding/json" 21 "fmt" 22 "io" 23 "mime" 24 "net" 25 "net/http" 26 "regexp" 27 "sort" 28 "sync" 29 "time" 30 31 "golang.org/x/oauth2" 32 33 "go.chromium.org/luci/auth" 34 "go.chromium.org/luci/auth/integration/localauth/rpcs" 35 "go.chromium.org/luci/common/data/rand/cryptorand" 36 "go.chromium.org/luci/common/data/stringset" 37 "go.chromium.org/luci/common/errors" 38 "go.chromium.org/luci/common/logging" 39 "go.chromium.org/luci/common/retry/transient" 40 "go.chromium.org/luci/common/runtime/paniccatcher" 41 "go.chromium.org/luci/lucictx" 42 43 "go.chromium.org/luci/auth/integration/internal/localsrv" 44 ) 45 46 // TokenGenerator produces access or ID tokens. 47 // 48 // The canonical implementation is &auth.TokenGenerator{}. 49 type TokenGenerator interface { 50 // GenerateOAuthToken returns an access token for a combination of scopes. 51 // 52 // It is called for each request to the local auth server. It may be called 53 // concurrently from multiple goroutines and must implement its own caching 54 // and synchronization if necessary. 
55 // 56 // It is expected that the returned token lives for at least given 'lifetime' 57 // duration (which is typically on order of minutes), but it may live longer. 58 // Clients may cache the returned token for the duration of its lifetime. 59 // 60 // May return transient errors (in transient.Tag.In(err) returning true 61 // sense). Such errors result in HTTP 500 responses. This is appropriate for 62 // non-fatal errors. Clients may immediately retry requests on such errors. 63 // 64 // Any non-transient error is considered fatal and results in an RPC-level 65 // error response ({"error": ...}). Clients must treat such responses as fatal 66 // and don't retry requests. 67 // 68 // If the error implements ErrorWithCode interface, the error code returned to 69 // clients will be grabbed from the error object, otherwise the error code is 70 // set to -1. 71 GenerateOAuthToken(ctx context.Context, scopes []string, lifetime time.Duration) (*oauth2.Token, error) 72 73 // GenerateIDToken returns an ID token with the given audience in `aud` claim. 74 // 75 // All details specified in GenerateOAuthToken doc also apply to 76 // GenerateIDToken. 77 GenerateIDToken(ctx context.Context, audience string, lifetime time.Duration) (*oauth2.Token, error) 78 79 // GetEmail returns an email associated with all tokens produced by this 80 // generator or auth.ErrNoEmail if it's not available. 81 // 82 // Any other error will bubble up through Server.Start. 83 GetEmail() (string, error) 84 } 85 86 // ErrorWithCode is a fatal error that also has a numeric code. 87 // 88 // May be returned by TokenGenerator to trigger a response with some specific 89 // error code. 90 type ErrorWithCode interface { 91 error 92 93 // Code returns a code to put into RPC response alongside the error message. 94 Code() int 95 } 96 97 // Server runs a local RPC server that hands out access tokens. 
98 // 99 // Processes that need a token can discover location of this server by looking 100 // at "local_auth" section of LUCI_CONTEXT. 101 type Server struct { 102 // TokenGenerators produce access tokens for given account IDs. 103 TokenGenerators map[string]TokenGenerator 104 105 // DefaultAccountID is account ID subprocesses should pick by default. 106 // 107 // It is put into "local_auth" section of LUCI_CONTEXT. If empty string, 108 // subprocesses won't attempt to use any account by default (they still can 109 // pick some non-default account though). 110 DefaultAccountID string 111 112 // Port is a local TCP port to bind to or 0 to allow the OS to pick one. 113 Port int 114 115 srv localsrv.Server 116 117 testingServeHook func() // called right before serving 118 } 119 120 // Start launches background goroutine with the serving loop. 121 // 122 // The provided context is used as base context for request handlers and for 123 // logging. 124 // 125 // Returns a copy of lucictx.LocalAuth structure that specifies how to contact 126 // the server. It should be put into "local_auth" section of LUCI_CONTEXT where 127 // clients can discover it. 128 // 129 // The server must be eventually stopped with Stop(). 130 func (s *Server) Start(ctx context.Context) (*lucictx.LocalAuth, error) { 131 la, err := s.initLocalAuth(ctx) 132 if err != nil { 133 return nil, errors.Annotate(err, "failed to initialize LocalAuth").Err() 134 } 135 136 addr, err := s.srv.Start(ctx, "local_auth", s.Port, func(c context.Context, l net.Listener, wg *sync.WaitGroup) error { 137 return s.serve(c, l, wg, la.Secret) 138 }) 139 if err != nil { 140 return nil, errors.Annotate(err, "failed to start the local server").Err() 141 } 142 143 la.RpcPort = uint32(addr.Port) 144 return la, nil 145 } 146 147 // Stop closes the listening socket, notifies pending requests to abort and 148 // stops the internal serving goroutine. 149 // 150 // Safe to call multiple times. 
Once stopped, the server cannot be started again 151 // (make a new instance of Server instead). 152 // 153 // Uses the given context for the deadline when waiting for the serving loop 154 // to stop. 155 func (s *Server) Stop(ctx context.Context) error { 156 return s.srv.Stop(ctx) 157 } 158 159 // initLocalAuth generates new LocalAuth struct with RPC port blank. 160 func (s *Server) initLocalAuth(ctx context.Context) (*lucictx.LocalAuth, error) { 161 // Build a sorted list of LocalAuthAccount to put into the context, grab 162 // emails from the generators. 163 ids := make([]string, 0, len(s.TokenGenerators)) 164 for id := range s.TokenGenerators { 165 ids = append(ids, id) 166 } 167 sort.Strings(ids) 168 accounts := make([]*lucictx.LocalAuthAccount, len(ids)) 169 for i, id := range ids { 170 email, err := s.TokenGenerators[id].GetEmail() 171 switch { 172 case err == auth.ErrNoEmail: 173 email = "-" 174 case err != nil: 175 return nil, errors.Annotate(err, "could not grab email of account %q", id).Err() 176 } 177 accounts[i] = &lucictx.LocalAuthAccount{Id: id, Email: email} 178 } 179 180 secret := make([]byte, 48) 181 if _, err := cryptorand.Read(ctx, secret); err != nil { 182 return nil, err 183 } 184 185 return &lucictx.LocalAuth{ 186 Secret: secret, 187 Accounts: accounts, 188 DefaultAccountId: s.DefaultAccountID, 189 }, nil 190 } 191 192 // serve runs the serving loop. 193 func (s *Server) serve(ctx context.Context, l net.Listener, wg *sync.WaitGroup, secret []byte) error { 194 if s.testingServeHook != nil { 195 s.testingServeHook() 196 } 197 srv := http.Server{ 198 Handler: &protocolHandler{ 199 ctx: ctx, 200 wg: wg, 201 secret: secret, 202 tokens: s.TokenGenerators, 203 }, 204 } 205 return srv.Serve(l) 206 } 207 208 //////////////////////////////////////////////////////////////////////////////// 209 // Protocol implementation. 210 211 // methodRe defines an URL of RPC method handler. 
var methodRe = regexp.MustCompile(
	// Submatch 1 captures <method> in /rpc/LuciLocalAuthService.<method>.
	`^/rpc/LuciLocalAuthService\.([a-zA-Z0-9_]+)$`,
)

// minTokenLifetime is the lifetime the server asks TokenGenerator for when
// minting tokens.
//
// It must exceed 'minAcceptedLifetime' in the auth package; otherwise an
// auth.Authenticator built on top of the local_auth server may misbehave.
const minTokenLifetime time.Duration = 3 * time.Minute

// protocolHandler is the http.Handler behind the local auth server. http.Server
// calls its ServeHTTP in a separate goroutine to handle each request.
//
// It implements the server side of the local_auth RPC protocol:
//   - Each request is a POST to /rpc/LuciLocalAuthService.<Method>.
//   - The request Content-Type, if set, must be "application/json; ...".
//   - The sender must set the Content-Length header (body must be <64Kb).
//   - The response Content-Type is also "application/json".
//   - The server sets the Content-Length header in the response.
//   - Protocol-level errors have a non-200 HTTP status code.
//   - Logic errors have a 200 HTTP status code and the error is communicated
//     in the response body.
232 // 233 // Supported methods are: 234 // 235 // GetOAuthToken: 236 // 237 // Request body: 238 // { 239 // "scopes": [<string scope1>, <string scope2>, ...], 240 // "secret": <string from LUCI_CONTEXT.local_auth.secret>, 241 // "account_id": <ID of some account from LUCI_CONTEXT.local_auth.accounts> 242 // } 243 // Response body: 244 // { 245 // "error_code": <int, on success not set or 0>, 246 // "error_message": <string, on success not set>, 247 // "access_token": <string with actual token (on success)>, 248 // "expiry": <int with unix timestamp in seconds (on success)> 249 // } 250 // 251 // GetIDToken: 252 // 253 // Request body: 254 // { 255 // "audience": <string>, 256 // "secret": <string from LUCI_CONTEXT.local_auth.secret>, 257 // "account_id": <ID of some account from LUCI_CONTEXT.local_auth.accounts> 258 // } 259 // Response body: 260 // { 261 // "error_code": <int, on success not set or 0>, 262 // "error_message": <string, on success not set>, 263 // "id_token": <string with actual token (on success)>, 264 // "expiry": <int with unix timestamp in seconds (on success)> 265 // } 266 // 267 // See also python counterpart of this code: 268 // https://chromium.googlesource.com/infra/luci/luci-py/+/HEAD/client/utils/auth_server.py 269 type protocolHandler struct { 270 ctx context.Context // the parent context 271 wg *sync.WaitGroup // used for graceful shutdown 272 secret []byte // expected "secret" value 273 tokens map[string]TokenGenerator // the actual producer of tokens (per account) 274 } 275 276 // protocolError triggers an HTTP reply with some non-200 status code. 277 type protocolError struct { 278 Status int // HTTP status to set 279 Message string // the message to put in the body 280 } 281 282 func (e *protocolError) Error() string { 283 return fmt.Sprintf("%s (HTTP %d)", e.Message, e.Status) 284 } 285 286 // ServeHTTP implements the protocol marshaling logic. 
287 func (h *protocolHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { 288 h.wg.Add(1) 289 defer h.wg.Done() 290 291 defer paniccatcher.Catch(func(p *paniccatcher.Panic) { 292 logging.Fields{ 293 "panic.error": p.Reason, 294 }.Errorf(h.ctx, "Caught panic during handling of %q: %s\n%s", r.RequestURI, p.Reason, p.Stack) 295 http.Error(rw, "Internal Server Error. See logs.", http.StatusInternalServerError) 296 }) 297 298 logging.Debugf(h.ctx, "Handling %s %s", r.Method, r.RequestURI) 299 300 if r.Method != "POST" { 301 http.Error(rw, "Expecting POST", http.StatusMethodNotAllowed) 302 return 303 } 304 305 // Grab <method> from /rpc/LuciLocalAuthService.<method>. 306 matches := methodRe.FindStringSubmatch(r.RequestURI) 307 if len(matches) != 2 { 308 http.Error(rw, "Expecting /rpc/LuciLocalAuthService.<method>", http.StatusNotFound) 309 return 310 } 311 method := matches[1] 312 313 // The content type must be JSON, which is also the default. 314 if ct := r.Header.Get("Content-Type"); ct != "" { 315 baseType, _, err := mime.ParseMediaType(ct) 316 if err != nil { 317 http.Error(rw, fmt.Sprintf("Can't parse Content-Type: %s", err), http.StatusBadRequest) 318 return 319 } 320 if baseType != "application/json" { 321 http.Error(rw, "Expecting 'application/json' Content-Type", http.StatusBadRequest) 322 return 323 } 324 } 325 326 // The content length must be given and be small enough. 327 if r.ContentLength < 0 || r.ContentLength >= 64*1024 { 328 http.Error(rw, "Expecting 'Content-Length' header, <64Kb", http.StatusBadRequest) 329 return 330 } 331 332 // Slurp the body, it's easier to deal with []byte going forward. The body is 333 // tiny anyway. 334 request := make([]byte, r.ContentLength) 335 if _, err := io.ReadFull(r.Body, request); err != nil { 336 http.Error(rw, "Can't read the request body", http.StatusBadGateway) 337 return 338 } 339 340 // Route to the appropriate RPC handler. 
341 response, err := h.routeToImpl(method, request) 342 343 // *protocolError are sent as HTTP errors. 344 if pErr, _ := err.(*protocolError); pErr != nil { 345 http.Error(rw, pErr.Message, pErr.Status) 346 return 347 } 348 349 // Transient errors are returned as HTTP 500 responses. 350 if transient.Tag.In(err) { 351 http.Error(rw, fmt.Sprintf("Transient error - %s", err), http.StatusInternalServerError) 352 return 353 } 354 355 // Fatal errors are returned as specially structured JSON responses with 356 // HTTP 200 code. Replace 'response' with it. 357 if err != nil { 358 fatalError := rpcs.BaseResponse{ 359 ErrorCode: -1, 360 ErrorMessage: err.Error(), 361 } 362 if withCode, ok := err.(ErrorWithCode); ok && withCode.Code() != 0 { 363 fatalError.ErrorCode = withCode.Code() 364 } 365 response = &fatalError 366 } 367 368 // Serialize the response to grab its length. 369 blob, err := json.Marshal(response) 370 if err != nil { 371 http.Error(rw, fmt.Sprintf("Failed to serialize the response - %s", err), http.StatusInternalServerError) 372 return 373 } 374 blob = append(blob, '\n') // for curl's sake 375 376 // Finally write the response. 377 rw.Header().Set("Content-Type", "application/json; charset=utf-8") 378 rw.Header().Set("Content-Length", fmt.Sprintf("%d", len(blob))) 379 rw.WriteHeader(http.StatusOK) 380 if _, err := rw.Write(blob); err != nil { 381 logging.WithError(err).Warningf(h.ctx, "Failed to write the response") 382 } 383 } 384 385 // routeToImpl calls appropriate RPC method implementation. 
386 func (h *protocolHandler) routeToImpl(method string, request []byte) (any, error) { 387 switch method { 388 case "GetOAuthToken": 389 req := &rpcs.GetOAuthTokenRequest{} 390 if err := unmarshalRequest(request, req); err != nil { 391 return nil, err 392 } 393 return h.handleGetOAuthToken(req) 394 case "GetIDToken": 395 req := &rpcs.GetIDTokenRequest{} 396 if err := unmarshalRequest(request, req); err != nil { 397 return nil, err 398 } 399 return h.handleGetIDToken(req) 400 default: 401 return nil, &protocolError{ 402 Status: http.StatusNotFound, 403 Message: fmt.Sprintf("Unknown RPC method %q", method), 404 } 405 } 406 } 407 408 // unmarshalRequest unmarshals JSON body of the request, handling errors. 409 func unmarshalRequest(blob []byte, req any) error { 410 if err := json.Unmarshal(blob, req); err != nil { 411 return &protocolError{ 412 Status: http.StatusBadRequest, 413 Message: fmt.Sprintf("Not JSON body - %s", err), 414 } 415 } 416 return nil 417 } 418 419 //////////////////////////////////////////////////////////////////////////////// 420 // RPC implementations. 421 422 // checkSecretAndAccount checks the secret string in the request and looks up 423 // the TokenGenerator based on the account ID in the request. 
424 func (h *protocolHandler) checkSecretAndAccount(req *rpcs.BaseRequest) (TokenGenerator, error) { 425 if subtle.ConstantTimeCompare(h.secret, req.Secret) != 1 { 426 return nil, &protocolError{ 427 Status: 403, 428 Message: "Invalid secret.", 429 } 430 } 431 generator := h.tokens[req.AccountID] 432 if generator == nil { 433 return nil, &protocolError{ 434 Status: 404, 435 Message: fmt.Sprintf("Unrecognized account ID %q.", req.AccountID), 436 } 437 } 438 return generator, nil 439 } 440 441 func (h *protocolHandler) handleGetOAuthToken(req *rpcs.GetOAuthTokenRequest) (*rpcs.GetOAuthTokenResponse, error) { 442 if err := req.Validate(); err != nil { 443 return nil, &protocolError{ 444 Status: 400, 445 Message: fmt.Sprintf("Bad request: %s.", err.Error()), 446 } 447 } 448 generator, err := h.checkSecretAndAccount(&req.BaseRequest) 449 if err != nil { 450 return nil, err 451 } 452 453 // Dedup and sort scopes. 454 scopes := stringset.New(len(req.Scopes)) 455 for _, s := range req.Scopes { 456 scopes.Add(s) 457 } 458 sortedScopes := scopes.ToSortedSlice() 459 460 // Note: this may produce ErrorWithCode. 461 tok, err := generator.GenerateOAuthToken(h.ctx, sortedScopes, minTokenLifetime) 462 if err != nil { 463 return nil, err 464 } 465 return &rpcs.GetOAuthTokenResponse{ 466 AccessToken: tok.AccessToken, 467 Expiry: tok.Expiry.Unix(), 468 }, nil 469 } 470 471 func (h *protocolHandler) handleGetIDToken(req *rpcs.GetIDTokenRequest) (*rpcs.GetIDTokenResponse, error) { 472 if err := req.Validate(); err != nil { 473 return nil, &protocolError{ 474 Status: 400, 475 Message: fmt.Sprintf("Bad request: %s.", err.Error()), 476 } 477 } 478 generator, err := h.checkSecretAndAccount(&req.BaseRequest) 479 if err != nil { 480 return nil, err 481 } 482 483 // Note: this may produce ErrorWithCode. 
484 tok, err := generator.GenerateIDToken(h.ctx, req.Audience, minTokenLifetime) 485 if err != nil { 486 return nil, err 487 } 488 return &rpcs.GetIDTokenResponse{ 489 IDToken: tok.AccessToken, // this is actually an ID token 490 Expiry: tok.Expiry.Unix(), 491 }, nil 492 }