go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/auth/integration/localauth/server.go (about)

     1  // Copyright 2017 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package localauth
    16  
    17  import (
    18  	"context"
    19  	"crypto/subtle"
    20  	"encoding/json"
    21  	"fmt"
    22  	"io"
    23  	"mime"
    24  	"net"
    25  	"net/http"
    26  	"regexp"
    27  	"sort"
    28  	"sync"
    29  	"time"
    30  
    31  	"golang.org/x/oauth2"
    32  
    33  	"go.chromium.org/luci/auth"
    34  	"go.chromium.org/luci/auth/integration/localauth/rpcs"
    35  	"go.chromium.org/luci/common/data/rand/cryptorand"
    36  	"go.chromium.org/luci/common/data/stringset"
    37  	"go.chromium.org/luci/common/errors"
    38  	"go.chromium.org/luci/common/logging"
    39  	"go.chromium.org/luci/common/retry/transient"
    40  	"go.chromium.org/luci/common/runtime/paniccatcher"
    41  	"go.chromium.org/luci/lucictx"
    42  
    43  	"go.chromium.org/luci/auth/integration/internal/localsrv"
    44  )
    45  
// TokenGenerator produces access or ID tokens.
//
// The canonical implementation is &auth.TokenGenerator{}.
type TokenGenerator interface {
	// GenerateOAuthToken returns an access token for a combination of scopes.
	//
	// It is called for each request to the local auth server. It may be called
	// concurrently from multiple goroutines and must implement its own caching
	// and synchronization if necessary.
	//
	// It is expected that the returned token lives for at least the given
	// 'lifetime' duration (which is typically on the order of minutes), but it
	// may live longer. Clients may cache the returned token for the duration of
	// its lifetime.
	//
	// May return transient errors (i.e. errors for which transient.Tag.In(err)
	// returns true). Such errors result in HTTP 500 responses. This is
	// appropriate for non-fatal errors. Clients may immediately retry requests
	// on such errors.
	//
	// Any non-transient error is considered fatal and results in an RPC-level
	// error response (an HTTP 200 reply whose JSON body carries "error_code"
	// and "error_message"). Clients must treat such responses as fatal and must
	// not retry the request.
	//
	// If the error implements the ErrorWithCode interface, the error code
	// returned to clients is taken from the error object, otherwise the error
	// code is set to -1.
	GenerateOAuthToken(ctx context.Context, scopes []string, lifetime time.Duration) (*oauth2.Token, error)

	// GenerateIDToken returns an ID token with the given audience in `aud` claim.
	//
	// The token string is expected in the AccessToken field of the returned
	// oauth2.Token. All details specified in the GenerateOAuthToken doc also
	// apply to GenerateIDToken.
	GenerateIDToken(ctx context.Context, audience string, lifetime time.Duration) (*oauth2.Token, error)

	// GetEmail returns an email associated with all tokens produced by this
	// generator or auth.ErrNoEmail if it's not available.
	//
	// Any other error will bubble up through Server.Start.
	GetEmail() (string, error)
}
    85  
// ErrorWithCode is a fatal error that also has a numeric code.
//
// May be returned by TokenGenerator to trigger a response with some specific
// error code.
type ErrorWithCode interface {
	error

	// Code returns a code to put into RPC response alongside the error message.
	//
	// A zero code is ignored by the server: the response carries -1 instead.
	Code() int
}
    96  
// Server runs a local RPC server that hands out access tokens.
//
// Processes that need a token can discover location of this server by looking
// at "local_auth" section of LUCI_CONTEXT.
type Server struct {
	// TokenGenerators produce access tokens for given account IDs.
	//
	// The map keys are the account IDs advertised to clients.
	TokenGenerators map[string]TokenGenerator

	// DefaultAccountID is account ID subprocesses should pick by default.
	//
	// It is put into "local_auth" section of LUCI_CONTEXT. If empty string,
	// subprocesses won't attempt to use any account by default (they still can
	// pick some non-default account though).
	DefaultAccountID string

	// Port is a local TCP port to bind to or 0 to allow the OS to pick one.
	Port int

	// srv is the underlying local server that Start and Stop delegate to.
	srv localsrv.Server

	testingServeHook func() // called right before serving
}
   119  
   120  // Start launches background goroutine with the serving loop.
   121  //
   122  // The provided context is used as base context for request handlers and for
   123  // logging.
   124  //
   125  // Returns a copy of lucictx.LocalAuth structure that specifies how to contact
   126  // the server. It should be put into "local_auth" section of LUCI_CONTEXT where
   127  // clients can discover it.
   128  //
   129  // The server must be eventually stopped with Stop().
   130  func (s *Server) Start(ctx context.Context) (*lucictx.LocalAuth, error) {
   131  	la, err := s.initLocalAuth(ctx)
   132  	if err != nil {
   133  		return nil, errors.Annotate(err, "failed to initialize LocalAuth").Err()
   134  	}
   135  
   136  	addr, err := s.srv.Start(ctx, "local_auth", s.Port, func(c context.Context, l net.Listener, wg *sync.WaitGroup) error {
   137  		return s.serve(c, l, wg, la.Secret)
   138  	})
   139  	if err != nil {
   140  		return nil, errors.Annotate(err, "failed to start the local server").Err()
   141  	}
   142  
   143  	la.RpcPort = uint32(addr.Port)
   144  	return la, nil
   145  }
   146  
// Stop closes the listening socket, notifies pending requests to abort and
// stops the internal serving goroutine.
//
// Safe to call multiple times. Once stopped, the server cannot be started again
// (make a new instance of Server instead).
//
// Uses the given context for the deadline when waiting for the serving loop
// to stop.
//
// Returns the error reported by the underlying local server's Stop.
func (s *Server) Stop(ctx context.Context) error {
	return s.srv.Stop(ctx)
}
   158  
   159  // initLocalAuth generates new LocalAuth struct with RPC port blank.
   160  func (s *Server) initLocalAuth(ctx context.Context) (*lucictx.LocalAuth, error) {
   161  	// Build a sorted list of LocalAuthAccount to put into the context, grab
   162  	// emails from the generators.
   163  	ids := make([]string, 0, len(s.TokenGenerators))
   164  	for id := range s.TokenGenerators {
   165  		ids = append(ids, id)
   166  	}
   167  	sort.Strings(ids)
   168  	accounts := make([]*lucictx.LocalAuthAccount, len(ids))
   169  	for i, id := range ids {
   170  		email, err := s.TokenGenerators[id].GetEmail()
   171  		switch {
   172  		case err == auth.ErrNoEmail:
   173  			email = "-"
   174  		case err != nil:
   175  			return nil, errors.Annotate(err, "could not grab email of account %q", id).Err()
   176  		}
   177  		accounts[i] = &lucictx.LocalAuthAccount{Id: id, Email: email}
   178  	}
   179  
   180  	secret := make([]byte, 48)
   181  	if _, err := cryptorand.Read(ctx, secret); err != nil {
   182  		return nil, err
   183  	}
   184  
   185  	return &lucictx.LocalAuth{
   186  		Secret:           secret,
   187  		Accounts:         accounts,
   188  		DefaultAccountId: s.DefaultAccountID,
   189  	}, nil
   190  }
   191  
   192  // serve runs the serving loop.
   193  func (s *Server) serve(ctx context.Context, l net.Listener, wg *sync.WaitGroup, secret []byte) error {
   194  	if s.testingServeHook != nil {
   195  		s.testingServeHook()
   196  	}
   197  	srv := http.Server{
   198  		Handler: &protocolHandler{
   199  			ctx:    ctx,
   200  			wg:     wg,
   201  			secret: secret,
   202  			tokens: s.TokenGenerators,
   203  		},
   204  	}
   205  	return srv.Serve(l)
   206  }
   207  
   208  ////////////////////////////////////////////////////////////////////////////////
   209  // Protocol implementation.
   210  
// methodRe matches the URL of an RPC method handler.
//
// Its only capture group is the method name, which may contain ASCII letters,
// digits and underscores.
var methodRe = regexp.MustCompile(`^/rpc/LuciLocalAuthService\.([a-zA-Z0-9_]+)$`)
   213  
// minTokenLifetime is the lifetime of tokens requested through TokenGenerator.
//
// It is passed as the 'lifetime' argument to GenerateOAuthToken and
// GenerateIDToken.
//
// Must be larger than 'minAcceptedLifetime' in the auth package, or weird
// things may happen if local_auth server is used as a basis for some
// auth.Authenticator.
const minTokenLifetime = 3 * time.Minute
   220  
// protocolHandler is the http.Handler installed into the http.Server; its
// ServeHTTP is called in a separate goroutine to handle each request.
//
// It implements the server side of local_auth RPC protocol:
//   - Each request is POST to /rpc/LuciLocalAuthService.<Method>
//   - Request content type is "application/json; ..." (also assumed when the
//     Content-Type header is absent).
//   - The sender must set Content-Length header (and the body must be smaller
//     than 64 KiB).
//   - Response content type is also "application/json".
//   - The server sets Content-Length header in the response.
//   - Protocol-level errors have non-200 HTTP status code.
//   - Logic errors have 200 HTTP status code and error is communicated in
//     the response body.
//
// Supported methods are:
//
// GetOAuthToken:
//
//	Request body:
//	{
//	  "scopes": [<string scope1>, <string scope2>, ...],
//	  "secret": <string from LUCI_CONTEXT.local_auth.secret>,
//	  "account_id": <ID of some account from LUCI_CONTEXT.local_auth.accounts>
//	}
//	Response body:
//	{
//	  "error_code": <int, on success not set or 0>,
//	  "error_message": <string, on success not set>,
//	  "access_token": <string with actual token (on success)>,
//	  "expiry": <int with unix timestamp in seconds (on success)>
//	}
//
// GetIDToken:
//
//	Request body:
//	{
//	  "audience": <string>,
//	  "secret": <string from LUCI_CONTEXT.local_auth.secret>,
//	  "account_id": <ID of some account from LUCI_CONTEXT.local_auth.accounts>
//	}
//	Response body:
//	{
//	  "error_code": <int, on success not set or 0>,
//	  "error_message": <string, on success not set>,
//	  "id_token": <string with actual token (on success)>,
//	  "expiry": <int with unix timestamp in seconds (on success)>
//	}
//
// See also python counterpart of this code:
// https://chromium.googlesource.com/infra/luci/luci-py/+/HEAD/client/utils/auth_server.py
type protocolHandler struct {
	ctx    context.Context           // the parent context
	wg     *sync.WaitGroup           // used for graceful shutdown
	secret []byte                    // expected "secret" value
	tokens map[string]TokenGenerator // the actual producer of tokens (per account)
}
   275  
   276  // protocolError triggers an HTTP reply with some non-200 status code.
   277  type protocolError struct {
   278  	Status  int    // HTTP status to set
   279  	Message string // the message to put in the body
   280  }
   281  
   282  func (e *protocolError) Error() string {
   283  	return fmt.Sprintf("%s (HTTP %d)", e.Message, e.Status)
   284  }
   285  
   286  // ServeHTTP implements the protocol marshaling logic.
   287  func (h *protocolHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
   288  	h.wg.Add(1)
   289  	defer h.wg.Done()
   290  
   291  	defer paniccatcher.Catch(func(p *paniccatcher.Panic) {
   292  		logging.Fields{
   293  			"panic.error": p.Reason,
   294  		}.Errorf(h.ctx, "Caught panic during handling of %q: %s\n%s", r.RequestURI, p.Reason, p.Stack)
   295  		http.Error(rw, "Internal Server Error. See logs.", http.StatusInternalServerError)
   296  	})
   297  
   298  	logging.Debugf(h.ctx, "Handling %s %s", r.Method, r.RequestURI)
   299  
   300  	if r.Method != "POST" {
   301  		http.Error(rw, "Expecting POST", http.StatusMethodNotAllowed)
   302  		return
   303  	}
   304  
   305  	// Grab <method> from /rpc/LuciLocalAuthService.<method>.
   306  	matches := methodRe.FindStringSubmatch(r.RequestURI)
   307  	if len(matches) != 2 {
   308  		http.Error(rw, "Expecting /rpc/LuciLocalAuthService.<method>", http.StatusNotFound)
   309  		return
   310  	}
   311  	method := matches[1]
   312  
   313  	// The content type must be JSON, which is also the default.
   314  	if ct := r.Header.Get("Content-Type"); ct != "" {
   315  		baseType, _, err := mime.ParseMediaType(ct)
   316  		if err != nil {
   317  			http.Error(rw, fmt.Sprintf("Can't parse Content-Type: %s", err), http.StatusBadRequest)
   318  			return
   319  		}
   320  		if baseType != "application/json" {
   321  			http.Error(rw, "Expecting 'application/json' Content-Type", http.StatusBadRequest)
   322  			return
   323  		}
   324  	}
   325  
   326  	// The content length must be given and be small enough.
   327  	if r.ContentLength < 0 || r.ContentLength >= 64*1024 {
   328  		http.Error(rw, "Expecting 'Content-Length' header, <64Kb", http.StatusBadRequest)
   329  		return
   330  	}
   331  
   332  	// Slurp the body, it's easier to deal with []byte going forward. The body is
   333  	// tiny anyway.
   334  	request := make([]byte, r.ContentLength)
   335  	if _, err := io.ReadFull(r.Body, request); err != nil {
   336  		http.Error(rw, "Can't read the request body", http.StatusBadGateway)
   337  		return
   338  	}
   339  
   340  	// Route to the appropriate RPC handler.
   341  	response, err := h.routeToImpl(method, request)
   342  
   343  	// *protocolError are sent as HTTP errors.
   344  	if pErr, _ := err.(*protocolError); pErr != nil {
   345  		http.Error(rw, pErr.Message, pErr.Status)
   346  		return
   347  	}
   348  
   349  	// Transient errors are returned as HTTP 500 responses.
   350  	if transient.Tag.In(err) {
   351  		http.Error(rw, fmt.Sprintf("Transient error - %s", err), http.StatusInternalServerError)
   352  		return
   353  	}
   354  
   355  	// Fatal errors are returned as specially structured JSON responses with
   356  	// HTTP 200 code. Replace 'response' with it.
   357  	if err != nil {
   358  		fatalError := rpcs.BaseResponse{
   359  			ErrorCode:    -1,
   360  			ErrorMessage: err.Error(),
   361  		}
   362  		if withCode, ok := err.(ErrorWithCode); ok && withCode.Code() != 0 {
   363  			fatalError.ErrorCode = withCode.Code()
   364  		}
   365  		response = &fatalError
   366  	}
   367  
   368  	// Serialize the response to grab its length.
   369  	blob, err := json.Marshal(response)
   370  	if err != nil {
   371  		http.Error(rw, fmt.Sprintf("Failed to serialize the response - %s", err), http.StatusInternalServerError)
   372  		return
   373  	}
   374  	blob = append(blob, '\n') // for curl's sake
   375  
   376  	// Finally write the response.
   377  	rw.Header().Set("Content-Type", "application/json; charset=utf-8")
   378  	rw.Header().Set("Content-Length", fmt.Sprintf("%d", len(blob)))
   379  	rw.WriteHeader(http.StatusOK)
   380  	if _, err := rw.Write(blob); err != nil {
   381  		logging.WithError(err).Warningf(h.ctx, "Failed to write the response")
   382  	}
   383  }
   384  
   385  // routeToImpl calls appropriate RPC method implementation.
   386  func (h *protocolHandler) routeToImpl(method string, request []byte) (any, error) {
   387  	switch method {
   388  	case "GetOAuthToken":
   389  		req := &rpcs.GetOAuthTokenRequest{}
   390  		if err := unmarshalRequest(request, req); err != nil {
   391  			return nil, err
   392  		}
   393  		return h.handleGetOAuthToken(req)
   394  	case "GetIDToken":
   395  		req := &rpcs.GetIDTokenRequest{}
   396  		if err := unmarshalRequest(request, req); err != nil {
   397  			return nil, err
   398  		}
   399  		return h.handleGetIDToken(req)
   400  	default:
   401  		return nil, &protocolError{
   402  			Status:  http.StatusNotFound,
   403  			Message: fmt.Sprintf("Unknown RPC method %q", method),
   404  		}
   405  	}
   406  }
   407  
   408  // unmarshalRequest unmarshals JSON body of the request, handling errors.
   409  func unmarshalRequest(blob []byte, req any) error {
   410  	if err := json.Unmarshal(blob, req); err != nil {
   411  		return &protocolError{
   412  			Status:  http.StatusBadRequest,
   413  			Message: fmt.Sprintf("Not JSON body - %s", err),
   414  		}
   415  	}
   416  	return nil
   417  }
   418  
   419  ////////////////////////////////////////////////////////////////////////////////
   420  // RPC implementations.
   421  
   422  // checkSecretAndAccount checks the secret string in the request and looks up
   423  // the TokenGenerator based on the account ID in the request.
   424  func (h *protocolHandler) checkSecretAndAccount(req *rpcs.BaseRequest) (TokenGenerator, error) {
   425  	if subtle.ConstantTimeCompare(h.secret, req.Secret) != 1 {
   426  		return nil, &protocolError{
   427  			Status:  403,
   428  			Message: "Invalid secret.",
   429  		}
   430  	}
   431  	generator := h.tokens[req.AccountID]
   432  	if generator == nil {
   433  		return nil, &protocolError{
   434  			Status:  404,
   435  			Message: fmt.Sprintf("Unrecognized account ID %q.", req.AccountID),
   436  		}
   437  	}
   438  	return generator, nil
   439  }
   440  
   441  func (h *protocolHandler) handleGetOAuthToken(req *rpcs.GetOAuthTokenRequest) (*rpcs.GetOAuthTokenResponse, error) {
   442  	if err := req.Validate(); err != nil {
   443  		return nil, &protocolError{
   444  			Status:  400,
   445  			Message: fmt.Sprintf("Bad request: %s.", err.Error()),
   446  		}
   447  	}
   448  	generator, err := h.checkSecretAndAccount(&req.BaseRequest)
   449  	if err != nil {
   450  		return nil, err
   451  	}
   452  
   453  	// Dedup and sort scopes.
   454  	scopes := stringset.New(len(req.Scopes))
   455  	for _, s := range req.Scopes {
   456  		scopes.Add(s)
   457  	}
   458  	sortedScopes := scopes.ToSortedSlice()
   459  
   460  	// Note: this may produce ErrorWithCode.
   461  	tok, err := generator.GenerateOAuthToken(h.ctx, sortedScopes, minTokenLifetime)
   462  	if err != nil {
   463  		return nil, err
   464  	}
   465  	return &rpcs.GetOAuthTokenResponse{
   466  		AccessToken: tok.AccessToken,
   467  		Expiry:      tok.Expiry.Unix(),
   468  	}, nil
   469  }
   470  
   471  func (h *protocolHandler) handleGetIDToken(req *rpcs.GetIDTokenRequest) (*rpcs.GetIDTokenResponse, error) {
   472  	if err := req.Validate(); err != nil {
   473  		return nil, &protocolError{
   474  			Status:  400,
   475  			Message: fmt.Sprintf("Bad request: %s.", err.Error()),
   476  		}
   477  	}
   478  	generator, err := h.checkSecretAndAccount(&req.BaseRequest)
   479  	if err != nil {
   480  		return nil, err
   481  	}
   482  
   483  	// Note: this may produce ErrorWithCode.
   484  	tok, err := generator.GenerateIDToken(h.ctx, req.Audience, minTokenLifetime)
   485  	if err != nil {
   486  		return nil, err
   487  	}
   488  	return &rpcs.GetIDTokenResponse{
   489  		IDToken: tok.AccessToken, // this is actually an ID token
   490  		Expiry:  tok.Expiry.Unix(),
   491  	}, nil
   492  }