go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/auth/integration/firebase/firebase.go (about)

     1  // Copyright 2018 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package firebase implements an auth server that allows firebase-tools
    16  // to use an exposed OAuth2 TokenSource for auth.
    17  // See firebase auth documentation at https://github.com/firebase/firebase-tools
    18  // and auth implementation https://github.com/firebase/firebase-tools/blob/9422490bd87e934a097a110f77eddac799d965a4/lib/auth.js
    19  package firebase
    20  
    21  import (
    22  	"context"
    23  	"encoding/json"
    24  	"fmt"
    25  	"net"
    26  	"net/http"
    27  	"sync"
    28  	"time"
    29  
    30  	"golang.org/x/oauth2"
    31  
    32  	"go.chromium.org/luci/common/clock"
    33  	"go.chromium.org/luci/common/errors"
    34  	"go.chromium.org/luci/common/logging"
    35  	"go.chromium.org/luci/common/retry/transient"
    36  	"go.chromium.org/luci/common/runtime/paniccatcher"
    37  
    38  	"go.chromium.org/luci/auth/integration/internal/localsrv"
    39  )
    40  
    41  // Server runs a local server that handles requests to token_uri.
    42  type Server struct {
    43  	// Source is used to obtain OAuth2 tokens.
    44  	Source oauth2.TokenSource
    45  	// Port is a local TCP port to bind to or 0 to allow the OS to pick one.
    46  	Port int
    47  
    48  	srv localsrv.Server
    49  }
    50  
    51  // Start launches background goroutine with the serving loop.
    52  //
    53  // The provided context is used as base context for request handlers and for
    54  // logging. The server must be eventually stopped with Stop().
    55  //
    56  // Returns the root URL of a local server to use as FIREBASE_TOKEN_URL.
    57  func (s *Server) Start(ctx context.Context) (string, error) {
    58  	// Launch the server to get the port number.
    59  	addr, err := s.srv.Start(ctx, "firebase-auth", s.Port, func(c context.Context, l net.Listener, wg *sync.WaitGroup) error {
    60  		return s.serve(c, l, wg)
    61  	})
    62  	if err != nil {
    63  		return "", errors.Annotate(err, "failed to start the server").Err()
    64  	}
    65  	return fmt.Sprintf("http://%s", addr), nil
    66  }
    67  
    68  // Stop closes the listening socket, notifies pending requests to abort and
    69  // stops the internal serving goroutine.
    70  //
    71  // Safe to call multiple times. Once stopped, the server cannot be started again
    72  // (make a new instance of Server instead).
    73  //
    74  // Uses the given context for the deadline when waiting for the serving loop
    75  // to stop.
    76  func (s *Server) Stop(ctx context.Context) error {
    77  	return s.srv.Stop(ctx)
    78  }
    79  
    80  // serve runs the serving loop.
    81  func (s *Server) serve(ctx context.Context, l net.Listener, wg *sync.WaitGroup) error {
    82  	mux := http.NewServeMux()
    83  
    84  	mux.Handle("/oauth2/v3/token", &handler{ctx, wg, func(rw http.ResponseWriter, r *http.Request) {
    85  		err := s.handleTokenRequest(rw, r)
    86  
    87  		code := 0
    88  		msg := ""
    89  		if transient.Tag.In(err) {
    90  			code = http.StatusInternalServerError
    91  			msg = fmt.Sprintf("Transient error - %s", err)
    92  		} else if err != nil {
    93  			code = http.StatusBadRequest
    94  			msg = fmt.Sprintf("Bad request - %s", err)
    95  		}
    96  
    97  		if code != 0 {
    98  			logging.Errorf(ctx, "%s", msg)
    99  			http.Error(rw, msg, code)
   100  		}
   101  	}})
   102  
   103  	srv := http.Server{Handler: mux}
   104  	return srv.Serve(l)
   105  }
   106  
   107  // handleTokenRequest handles the OAuth2 flow.
   108  //
   109  // The body of the request is documented here (among many other places):
   110  //
   111  //	https://developers.google.com/identity/protocols/OAuth2InstalledApp#offline
   112  //
   113  // We ignore client_id and client_secret, since we aren't really running OAuth2.
   114  func (s *Server) handleTokenRequest(rw http.ResponseWriter, r *http.Request) error {
   115  	ctx := r.Context()
   116  
   117  	// We support only refreshing access token via 'refresh_token' grant.
   118  	if r.PostFormValue("grant_type") != "refresh_token" {
   119  		return fmt.Errorf("expecting 'refresh_token' grant type")
   120  	}
   121  
   122  	// Grab an access token through the source and return it.
   123  	tok, err := s.Source.Token()
   124  	if err != nil {
   125  		return err
   126  	}
   127  	rw.Header().Set("Content-Type", "application/json")
   128  	return json.NewEncoder(rw).Encode(map[string]any{
   129  		"access_token": tok.AccessToken,
   130  		"expires_in":   clock.Until(ctx, tok.Expiry) / time.Second,
   131  		"token_type":   "Bearer",
   132  	})
   133  }
   134  
   135  // handler implements http.Handler by wrapping the given handler with some
   136  // housekeeping stuff.
   137  type handler struct {
   138  	ctx     context.Context
   139  	wg      *sync.WaitGroup
   140  	handler http.HandlerFunc
   141  }
   142  
   143  func (h *handler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
   144  	h.wg.Add(1)
   145  	defer h.wg.Done()
   146  
   147  	defer paniccatcher.Catch(func(p *paniccatcher.Panic) {
   148  		logging.Fields{
   149  			"panic.error": p.Reason,
   150  		}.Errorf(h.ctx, "Caught panic during handling of %q: %s\n%s", r.RequestURI, p.Reason, p.Stack)
   151  		http.Error(rw, "Internal Server Error. See logs.", http.StatusInternalServerError)
   152  	})
   153  
   154  	logging.Debugf(h.ctx, "Handling %s %s", r.Method, r.RequestURI)
   155  	h.handler(rw, r.WithContext(h.ctx))
   156  }