go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/auth/internal/luci_ctx.go (about)

     1  // Copyright 2017 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package internal
    16  
    17  import (
    18  	"bytes"
    19  	"context"
    20  	"crypto/sha256"
    21  	"encoding/hex"
    22  	"encoding/json"
    23  	"fmt"
    24  	"io"
    25  	"net/http"
    26  	"strings"
    27  	"time"
    28  
    29  	"golang.org/x/net/context/ctxhttp"
    30  	"golang.org/x/oauth2"
    31  
    32  	"go.chromium.org/luci/auth/integration/localauth/rpcs"
    33  	"go.chromium.org/luci/common/logging"
    34  	"go.chromium.org/luci/common/retry/transient"
    35  	"go.chromium.org/luci/lucictx"
    36  )
    37  
    38  type luciContextTokenProvider struct {
    39  	localAuth *lucictx.LocalAuth
    40  	email     string // an email or NoEmail
    41  	scopes    []string
    42  	audience  string // not empty iff using ID tokens
    43  	transport http.RoundTripper
    44  	cacheKey  CacheKey // used only for in-memory cache
    45  }
    46  
    47  // NewLUCIContextTokenProvider returns TokenProvider that knows how to use a
    48  // local auth server to mint tokens.
    49  //
    50  // It requires LUCI_CONTEXT["local_auth"] to be present in the 'ctx'. It's a
    51  // description of how to locate and contact the local auth server.
    52  //
    53  // See auth/integration/localauth package for the implementation of the server.
    54  func NewLUCIContextTokenProvider(ctx context.Context, scopes []string, audience string, transport http.RoundTripper) (TokenProvider, error) {
    55  	localAuth := lucictx.GetLocalAuth(ctx)
    56  	switch {
    57  	case localAuth == nil:
    58  		return nil, fmt.Errorf(`no "local_auth" in LUCI_CONTEXT`)
    59  	case localAuth.DefaultAccountId == "":
    60  		return nil, fmt.Errorf(`no "default_account_id" in LUCI_CONTEXT["local_auth"]`)
    61  	}
    62  
    63  	// Grab an email associated with default account, if any.
    64  	email := NoEmail
    65  	for _, account := range localAuth.Accounts {
    66  		if account.Id == localAuth.DefaultAccountId {
    67  			// Previous protocol version didn't expose the email, so keep the value
    68  			// as NoEmail in this case. This should be rare.
    69  			if account.Email != "" {
    70  				email = account.Email
    71  			}
    72  			break
    73  		}
    74  	}
    75  
    76  	// All authenticators share singleton in-process token cache, see
    77  	// ProcTokenCache variable in proc_cache.go.
    78  	//
    79  	// It is possible (though very unusual), for a single process to use multiple
    80  	// local auth servers (e.g. if it enters a subcontext with another "local_auth"
    81  	// value).
    82  	//
    83  	// For these reasons we use a digest of localAuth parameters as a cache key.
    84  	// It is used only in the process-local cache, the token never ends up in
    85  	// the disk cache, as indicated by Lightweight() returning true (tokens from
    86  	// such providers aren't cached on disk by Authenticator).
    87  	blob, err := json.Marshal(localAuth)
    88  	if err != nil {
    89  		return nil, err
    90  	}
    91  	digest := sha256.Sum256(blob)
    92  
    93  	return &luciContextTokenProvider{
    94  		localAuth: localAuth,
    95  		email:     email,
    96  		scopes:    scopes,
    97  		audience:  audience,
    98  		transport: transport,
    99  		cacheKey: CacheKey{
   100  			Key:    fmt.Sprintf("luci_ctx/%s", hex.EncodeToString(digest[:])),
   101  			Scopes: scopes,
   102  		},
   103  	}, nil
   104  }
   105  
   106  func (p *luciContextTokenProvider) RequiresInteraction() bool {
   107  	return false
   108  }
   109  
   110  func (p *luciContextTokenProvider) Lightweight() bool {
   111  	return true
   112  }
   113  
   114  func (p *luciContextTokenProvider) Email() string {
   115  	return p.email
   116  }
   117  
   118  func (p *luciContextTokenProvider) CacheKey(ctx context.Context) (*CacheKey, error) {
   119  	return &p.cacheKey, nil
   120  }
   121  
   122  func (p *luciContextTokenProvider) MintToken(ctx context.Context, base *Token) (*Token, error) {
   123  	if p.audience == "" {
   124  		return p.mintOAuthToken(ctx)
   125  	}
   126  	return p.mintIDToken(ctx)
   127  }
   128  
   129  func (p *luciContextTokenProvider) mintOAuthToken(ctx context.Context) (*Token, error) {
   130  	request := &rpcs.GetOAuthTokenRequest{
   131  		BaseRequest: rpcs.BaseRequest{
   132  			Secret:    p.localAuth.Secret,
   133  			AccountID: p.localAuth.DefaultAccountId,
   134  		},
   135  		Scopes: p.scopes,
   136  	}
   137  	response := &rpcs.GetOAuthTokenResponse{}
   138  	if err := p.doRPC(ctx, "GetOAuthToken", request, response); err != nil {
   139  		return nil, err
   140  	}
   141  	if err := p.handleRPCErr(&response.BaseResponse); err != nil {
   142  		return nil, err
   143  	}
   144  	return &Token{
   145  		Token: oauth2.Token{
   146  			AccessToken: response.AccessToken,
   147  			Expiry:      time.Unix(response.Expiry, 0).UTC(),
   148  			TokenType:   "Bearer",
   149  		},
   150  		IDToken: NoIDToken,
   151  		Email:   p.Email(),
   152  	}, nil
   153  }
   154  
   155  func (p *luciContextTokenProvider) mintIDToken(ctx context.Context) (*Token, error) {
   156  	request := &rpcs.GetIDTokenRequest{
   157  		BaseRequest: rpcs.BaseRequest{
   158  			Secret:    p.localAuth.Secret,
   159  			AccountID: p.localAuth.DefaultAccountId,
   160  		},
   161  		Audience: p.audience,
   162  	}
   163  	response := &rpcs.GetIDTokenResponse{}
   164  	if err := p.doRPC(ctx, "GetIDToken", request, response); err != nil {
   165  		return nil, err
   166  	}
   167  	if err := p.handleRPCErr(&response.BaseResponse); err != nil {
   168  		return nil, err
   169  	}
   170  	return &Token{
   171  		Token: oauth2.Token{
   172  			AccessToken: NoAccessToken,
   173  			Expiry:      time.Unix(response.Expiry, 0).UTC(),
   174  			TokenType:   "Bearer",
   175  		},
   176  		IDToken: response.IDToken,
   177  		Email:   p.Email(),
   178  	}, nil
   179  }
   180  
   181  func (p *luciContextTokenProvider) RefreshToken(ctx context.Context, prev, base *Token) (*Token, error) {
   182  	// Minting and refreshing is the same thing: a call to a local auth server.
   183  	return p.MintToken(ctx, base)
   184  }
   185  
   186  // doRPC sends a request to the local auth server and parses the response.
   187  //
   188  // Note: deadlines and retries are implemented by Authenticator. doRPC should
   189  // just make a single attempt, and mark an error as transient to trigger a
   190  // retry, if necessary.
   191  func (p *luciContextTokenProvider) doRPC(ctx context.Context, method string, req, resp any) error {
   192  	body, err := json.Marshal(req)
   193  	if err != nil {
   194  		return err
   195  	}
   196  
   197  	url := fmt.Sprintf("http://127.0.0.1:%d/rpc/LuciLocalAuthService.%s", p.localAuth.RpcPort, method)
   198  	logging.Debugf(ctx, "POST %s", url)
   199  	httpReq, err := http.NewRequest("POST", url, bytes.NewReader(body))
   200  	if err != nil {
   201  		return err
   202  	}
   203  	httpReq.Header.Set("Content-Type", "application/json")
   204  
   205  	httpResp, err := ctxhttp.Do(ctx, &http.Client{Transport: p.transport}, httpReq)
   206  	if err != nil {
   207  		return transient.Tag.Apply(err)
   208  	}
   209  	defer httpResp.Body.Close()
   210  	respBody, err := io.ReadAll(httpResp.Body)
   211  	if err != nil {
   212  		return transient.Tag.Apply(err)
   213  	}
   214  
   215  	if httpResp.StatusCode != 200 {
   216  		err := fmt.Errorf("local auth - HTTP %d: %s", httpResp.StatusCode, strings.TrimSpace(string(respBody)))
   217  		if httpResp.StatusCode >= 500 {
   218  			return transient.Tag.Apply(err)
   219  		}
   220  		return err
   221  	}
   222  
   223  	return json.Unmarshal(respBody, resp)
   224  }
   225  
   226  // handleRPCErr handles `error_message` and `error_code` response fields.
   227  func (p *luciContextTokenProvider) handleRPCErr(resp *rpcs.BaseResponse) error {
   228  	if resp.ErrorMessage != "" || resp.ErrorCode != 0 {
   229  		msg := resp.ErrorMessage
   230  		if msg == "" {
   231  			msg = "unknown error"
   232  		}
   233  		return fmt.Errorf("local auth - RPC code %d: %s", resp.ErrorCode, msg)
   234  	}
   235  	return nil
   236  }