go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/auth/internal/luci_ctx_test.go (about)

     1  // Copyright 2017 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package internal
    16  
    17  import (
    18  	"context"
    19  	"encoding/json"
    20  	"net"
    21  	"net/http"
    22  	"net/http/httptest"
    23  	"testing"
    24  	"time"
    25  
    26  	"golang.org/x/oauth2"
    27  
    28  	"go.chromium.org/luci/auth/integration/localauth/rpcs"
    29  	"go.chromium.org/luci/common/retry/transient"
    30  	"go.chromium.org/luci/lucictx"
    31  
    32  	. "github.com/smartystreets/goconvey/convey"
    33  	. "go.chromium.org/luci/common/testing/assertions"
    34  )
    35  
    36  func TestLUCIContextProvider(t *testing.T) {
    37  	t.Parallel()
    38  
    39  	// Clear any existing LUCI_CONTEXT["local_auth"], it may be present if the
    40  	// test runs on a LUCI bot.
    41  	baseCtx := lucictx.Set(context.Background(), "local_auth", nil)
    42  
    43  	Convey("Requires local_auth", t, func() {
    44  		_, err := NewLUCIContextTokenProvider(baseCtx, []string{"A"}, "", http.DefaultTransport)
    45  		So(err, ShouldErrLike, `no "local_auth" in LUCI_CONTEXT`)
    46  	})
    47  
    48  	Convey("Requires default_account_id", t, func() {
    49  		ctx := lucictx.SetLocalAuth(baseCtx, &lucictx.LocalAuth{
    50  			Accounts: []*lucictx.LocalAuthAccount{{Id: "zzz"}},
    51  		})
    52  		_, err := NewLUCIContextTokenProvider(ctx, []string{"A"}, "", http.DefaultTransport)
    53  		So(err, ShouldErrLike, `no "default_account_id"`)
    54  	})
    55  
    56  	Convey("With mock server", t, func(c C) {
    57  		requests := make(chan any, 10000)
    58  		responses := make(chan any, 1)
    59  		ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    60  			c.So(r.Method, ShouldEqual, "POST")
    61  			c.So(r.Header.Get("Content-Type"), ShouldEqual, "application/json")
    62  
    63  			switch r.RequestURI {
    64  			case "/rpc/LuciLocalAuthService.GetOAuthToken":
    65  				req := rpcs.GetOAuthTokenRequest{}
    66  				c.So(json.NewDecoder(r.Body).Decode(&req), ShouldBeNil)
    67  				requests <- req
    68  			case "/rpc/LuciLocalAuthService.GetIDToken":
    69  				req := rpcs.GetIDTokenRequest{}
    70  				c.So(json.NewDecoder(r.Body).Decode(&req), ShouldBeNil)
    71  				requests <- req
    72  			default:
    73  				http.Error(w, "Unknown method", 404)
    74  			}
    75  
    76  			var resp any
    77  			select {
    78  			case resp = <-responses:
    79  			default:
    80  				panic("Unexpected request")
    81  			}
    82  
    83  			switch resp := resp.(type) {
    84  			case rpcs.GetOAuthTokenResponse:
    85  				w.WriteHeader(200)
    86  				c.So(json.NewEncoder(w).Encode(resp), ShouldBeNil)
    87  			case rpcs.GetIDTokenResponse:
    88  				w.WriteHeader(200)
    89  				c.So(json.NewEncoder(w).Encode(resp), ShouldBeNil)
    90  			case int:
    91  				http.Error(w, http.StatusText(resp), resp)
    92  			default:
    93  				panic("unexpected response type")
    94  			}
    95  		}))
    96  		defer ts.Close()
    97  
    98  		ctx := lucictx.SetLocalAuth(baseCtx, &lucictx.LocalAuth{
    99  			RpcPort: uint32(ts.Listener.Addr().(*net.TCPAddr).Port),
   100  			Secret:  []byte("zekret"),
   101  			Accounts: []*lucictx.LocalAuthAccount{
   102  				{Id: "acc_id", Email: "some-acc-email@example.com"},
   103  			},
   104  			DefaultAccountId: "acc_id",
   105  		})
   106  
   107  		Convey("Access tokens", func() {
   108  			p, err := NewLUCIContextTokenProvider(ctx, []string{"B", "A"}, "", http.DefaultTransport)
   109  			So(err, ShouldBeNil)
   110  
   111  			Convey("Happy path", func() {
   112  				responses <- rpcs.GetOAuthTokenResponse{
   113  					AccessToken: "zzz",
   114  					Expiry:      1487456796,
   115  				}
   116  
   117  				tok, err := p.MintToken(ctx, nil)
   118  				So(err, ShouldBeNil)
   119  				So(tok, ShouldResemble, &Token{
   120  					Token: oauth2.Token{
   121  						AccessToken: "zzz",
   122  						TokenType:   "Bearer",
   123  						Expiry:      time.Unix(1487456796, 0).UTC(),
   124  					},
   125  					IDToken: NoIDToken,
   126  					Email:   "some-acc-email@example.com",
   127  				})
   128  
   129  				So(<-requests, ShouldResemble, rpcs.GetOAuthTokenRequest{
   130  					BaseRequest: rpcs.BaseRequest{
   131  						Secret:    []byte("zekret"),
   132  						AccountID: "acc_id",
   133  					},
   134  					Scopes: []string{"B", "A"},
   135  				})
   136  			})
   137  
   138  			Convey("HTTP 500", func() {
   139  				responses <- 500
   140  				tok, err := p.MintToken(ctx, nil)
   141  				So(tok, ShouldBeNil)
   142  				So(err, ShouldErrLike, `local auth - HTTP 500`)
   143  				So(transient.Tag.In(err), ShouldBeTrue)
   144  			})
   145  
   146  			Convey("HTTP 403", func() {
   147  				responses <- 403
   148  				tok, err := p.MintToken(ctx, nil)
   149  				So(tok, ShouldBeNil)
   150  				So(err, ShouldErrLike, `local auth - HTTP 403`)
   151  				So(transient.Tag.In(err), ShouldBeFalse)
   152  			})
   153  
   154  			Convey("RPC level error", func() {
   155  				responses <- rpcs.GetOAuthTokenResponse{
   156  					BaseResponse: rpcs.BaseResponse{
   157  						ErrorCode:    123,
   158  						ErrorMessage: "omg, error",
   159  					},
   160  				}
   161  				tok, err := p.MintToken(ctx, nil)
   162  				So(tok, ShouldBeNil)
   163  				So(err, ShouldErrLike, `local auth - RPC code 123: omg, error`)
   164  				So(transient.Tag.In(err), ShouldBeFalse)
   165  			})
   166  		})
   167  
   168  		Convey("ID tokens", func() {
   169  			p, err := NewLUCIContextTokenProvider(ctx, []string{"audience:test-aud"}, "test-aud", http.DefaultTransport)
   170  			So(err, ShouldBeNil)
   171  
   172  			Convey("Happy path", func() {
   173  				responses <- rpcs.GetIDTokenResponse{
   174  					IDToken: "zzz",
   175  					Expiry:  1487456796,
   176  				}
   177  
   178  				tok, err := p.MintToken(ctx, nil)
   179  				So(err, ShouldBeNil)
   180  				So(tok, ShouldResemble, &Token{
   181  					Token: oauth2.Token{
   182  						AccessToken: NoAccessToken,
   183  						TokenType:   "Bearer",
   184  						Expiry:      time.Unix(1487456796, 0).UTC(),
   185  					},
   186  					IDToken: "zzz",
   187  					Email:   "some-acc-email@example.com",
   188  				})
   189  
   190  				So(<-requests, ShouldResemble, rpcs.GetIDTokenRequest{
   191  					BaseRequest: rpcs.BaseRequest{
   192  						Secret:    []byte("zekret"),
   193  						AccountID: "acc_id",
   194  					},
   195  					Audience: "test-aud",
   196  				})
   197  			})
   198  		})
   199  	})
   200  }