go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/auth/integration/localauth/server_test.go (about)

     1  // Copyright 2017 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package localauth
    16  
    17  import (
    18  	"bytes"
    19  	"context"
    20  	"encoding/json"
    21  	"fmt"
    22  	"io"
    23  	"net/http"
    24  	"strings"
    25  	"testing"
    26  	"time"
    27  
    28  	"golang.org/x/oauth2"
    29  
    30  	"go.chromium.org/luci/common/clock"
    31  	"go.chromium.org/luci/common/clock/testclock"
    32  	"go.chromium.org/luci/common/errors"
    33  	"go.chromium.org/luci/common/retry/transient"
    34  	"go.chromium.org/luci/lucictx"
    35  
    36  	. "github.com/smartystreets/goconvey/convey"
    37  	. "go.chromium.org/luci/common/testing/assertions"
    38  )
    39  
    40  type callbackGen struct {
    41  	email string
    42  	cb    func(context.Context, []string, time.Duration) (*oauth2.Token, error)
    43  }
    44  
    45  func (g *callbackGen) GenerateOAuthToken(ctx context.Context, scopes []string, lifetime time.Duration) (*oauth2.Token, error) {
    46  	return g.cb(ctx, scopes, lifetime)
    47  }
    48  
    49  func (g *callbackGen) GenerateIDToken(ctx context.Context, audience string, lifetime time.Duration) (*oauth2.Token, error) {
    50  	return g.cb(ctx, []string{"audience:" + audience}, lifetime)
    51  }
    52  
    53  func (g *callbackGen) GetEmail() (string, error) {
    54  	return g.email, nil
    55  }
    56  
    57  func makeGenerator(email string, cb func(context.Context, []string, time.Duration) (*oauth2.Token, error)) TokenGenerator {
    58  	return &callbackGen{email, cb}
    59  }
    60  
    61  func TestProtocol(t *testing.T) {
    62  	t.Parallel()
    63  
    64  	ctx := context.Background()
    65  	ctx, _ = testclock.UseTime(ctx, testclock.TestRecentTimeUTC)
    66  
    67  	Convey("With server", t, func(c C) {
    68  		// Use channels to pass mocked requests/responses back and forth.
    69  		requests := make(chan []string, 10000)
    70  		responses := make(chan any, 1)
    71  
    72  		testGen := func(ctx context.Context, scopes []string, lifetime time.Duration) (*oauth2.Token, error) {
    73  			requests <- scopes
    74  			var resp any
    75  			select {
    76  			case resp = <-responses:
    77  			default:
    78  				c.Println("Unexpected token request")
    79  				return nil, fmt.Errorf("Unexpected request")
    80  			}
    81  			switch resp := resp.(type) {
    82  			case error:
    83  				return nil, resp
    84  			case *oauth2.Token:
    85  				return resp, nil
    86  			default:
    87  				panic("unknown response")
    88  			}
    89  		}
    90  
    91  		s := Server{
    92  			TokenGenerators: map[string]TokenGenerator{
    93  				"acc_id":     makeGenerator("some@example.com", testGen),
    94  				"another_id": makeGenerator("another@example.com", testGen),
    95  			},
    96  			DefaultAccountID: "acc_id",
    97  		}
    98  		p, err := s.Start(ctx)
    99  		So(err, ShouldBeNil)
   100  		defer s.Stop(ctx)
   101  
   102  		So(p.Accounts[0], ShouldResembleProto, &lucictx.LocalAuthAccount{
   103  			Id: "acc_id", Email: "some@example.com",
   104  		})
   105  		So(p.Accounts[1], ShouldResembleProto, &lucictx.LocalAuthAccount{
   106  			Id: "another_id", Email: "another@example.com",
   107  		})
   108  		So(p.DefaultAccountId, ShouldEqual, "acc_id")
   109  
   110  		goodOAuthRequest := func() *http.Request {
   111  			return prepReq(p, "/rpc/LuciLocalAuthService.GetOAuthToken", map[string]any{
   112  				"scopes":     []string{"B", "A"},
   113  				"secret":     p.Secret,
   114  				"account_id": "acc_id",
   115  			})
   116  		}
   117  
   118  		goodIDTokRequest := func() *http.Request {
   119  			return prepReq(p, "/rpc/LuciLocalAuthService.GetIDToken", map[string]any{
   120  				"audience":   "A",
   121  				"secret":     p.Secret,
   122  				"account_id": "acc_id",
   123  			})
   124  		}
   125  
   126  		Convey("Access tokens happy path", func() {
   127  			responses <- &oauth2.Token{
   128  				AccessToken: "tok1",
   129  				Expiry:      clock.Now(ctx).Add(30 * time.Minute),
   130  			}
   131  			So(call(goodOAuthRequest()), ShouldEqual, `HTTP 200 (json): {"access_token":"tok1","expiry":1454474106}`)
   132  			So(<-requests, ShouldResemble, []string{"A", "B"})
   133  
   134  			// application/json is also the default.
   135  			req := goodOAuthRequest()
   136  			req.Header.Del("Content-Type")
   137  			responses <- &oauth2.Token{
   138  				AccessToken: "tok2",
   139  				Expiry:      clock.Now(ctx).Add(30 * time.Minute),
   140  			}
   141  			So(call(req), ShouldEqual, `HTTP 200 (json): {"access_token":"tok2","expiry":1454474106}`)
   142  			So(<-requests, ShouldResemble, []string{"A", "B"})
   143  		})
   144  
   145  		Convey("ID tokens happy path", func() {
   146  			responses <- &oauth2.Token{
   147  				AccessToken: "tok1",
   148  				Expiry:      clock.Now(ctx).Add(30 * time.Minute),
   149  			}
   150  			So(call(goodIDTokRequest()), ShouldEqual, `HTTP 200 (json): {"id_token":"tok1","expiry":1454474106}`)
   151  			So(<-requests, ShouldResemble, []string{"audience:A"})
   152  
   153  			// application/json is also the default.
   154  			req := goodIDTokRequest()
   155  			req.Header.Del("Content-Type")
   156  			responses <- &oauth2.Token{
   157  				AccessToken: "tok2",
   158  				Expiry:      clock.Now(ctx).Add(30 * time.Minute),
   159  			}
   160  			So(call(req), ShouldEqual, `HTTP 200 (json): {"id_token":"tok2","expiry":1454474106}`)
   161  			So(<-requests, ShouldResemble, []string{"audience:A"})
   162  		})
   163  
   164  		Convey("Panic in token generator", func() {
   165  			responses <- "omg, panic"
   166  			So(call(goodOAuthRequest()), ShouldEqual, `HTTP 500: Internal Server Error. See logs.`)
   167  		})
   168  
   169  		Convey("Not POST", func() {
   170  			req := goodOAuthRequest()
   171  			req.Method = "PUT"
   172  			So(call(req), ShouldEqual, `HTTP 405: Expecting POST`)
   173  		})
   174  
   175  		Convey("Bad URI", func() {
   176  			req := goodOAuthRequest()
   177  			req.URL.Path = "/zzz"
   178  			So(call(req), ShouldEqual, `HTTP 404: Expecting /rpc/LuciLocalAuthService.<method>`)
   179  		})
   180  
   181  		Convey("Bad content type", func() {
   182  			req := goodOAuthRequest()
   183  			req.Header.Set("Content-Type", "bzzzz")
   184  			So(call(req), ShouldEqual, `HTTP 400: Expecting 'application/json' Content-Type`)
   185  		})
   186  
   187  		Convey("Broken json", func() {
   188  			req := goodOAuthRequest()
   189  
   190  			body := `not a json`
   191  			req.Body = io.NopCloser(bytes.NewBufferString(body))
   192  			req.ContentLength = int64(len(body))
   193  
   194  			So(call(req), ShouldEqual, `HTTP 400: Not JSON body - invalid character 'o' in literal null (expecting 'u')`)
   195  		})
   196  
   197  		Convey("Huge request", func() {
   198  			req := goodOAuthRequest()
   199  
   200  			body := strings.Repeat("z", 64*1024+1)
   201  			req.Body = io.NopCloser(bytes.NewBufferString(body))
   202  			req.ContentLength = int64(len(body))
   203  
   204  			So(call(req), ShouldEqual, `HTTP 400: Expecting 'Content-Length' header, <64Kb`)
   205  		})
   206  
   207  		Convey("Unknown RPC method", func() {
   208  			req := prepReq(p, "/rpc/LuciLocalAuthService.UnknownMethod", map[string]any{})
   209  			So(call(req), ShouldEqual, `HTTP 404: Unknown RPC method "UnknownMethod"`)
   210  		})
   211  
   212  		Convey("No scopes", func() {
   213  			req := prepReq(p, "/rpc/LuciLocalAuthService.GetOAuthToken", map[string]any{
   214  				"secret":     p.Secret,
   215  				"account_id": "acc_id",
   216  			})
   217  			So(call(req), ShouldEqual, `HTTP 400: Bad request: field "scopes" is required.`)
   218  		})
   219  
   220  		Convey("No audience", func() {
   221  			req := prepReq(p, "/rpc/LuciLocalAuthService.GetIDToken", map[string]any{
   222  				"secret":     p.Secret,
   223  				"account_id": "acc_id",
   224  			})
   225  			So(call(req), ShouldEqual, `HTTP 400: Bad request: field "audience" is required.`)
   226  		})
   227  
   228  		Convey("No secret", func() {
   229  			req := prepReq(p, "/rpc/LuciLocalAuthService.GetOAuthToken", map[string]any{
   230  				"scopes":     []string{"B", "A"},
   231  				"account_id": "acc_id",
   232  			})
   233  			So(call(req), ShouldEqual, `HTTP 400: Bad request: field "secret" is required.`)
   234  		})
   235  
   236  		Convey("Bad secret", func() {
   237  			req := prepReq(p, "/rpc/LuciLocalAuthService.GetOAuthToken", map[string]any{
   238  				"scopes":     []string{"B", "A"},
   239  				"secret":     []byte{0, 1, 2, 3},
   240  				"account_id": "acc_id",
   241  			})
   242  			So(call(req), ShouldEqual, `HTTP 403: Invalid secret.`)
   243  		})
   244  
   245  		Convey("No account ID", func() {
   246  			req := prepReq(p, "/rpc/LuciLocalAuthService.GetOAuthToken", map[string]any{
   247  				"scopes": []string{"B", "A"},
   248  				"secret": p.Secret,
   249  			})
   250  			So(call(req), ShouldEqual, `HTTP 400: Bad request: field "account_id" is required.`)
   251  		})
   252  
   253  		Convey("Unknown account ID", func() {
   254  			req := prepReq(p, "/rpc/LuciLocalAuthService.GetOAuthToken", map[string]any{
   255  				"scopes":     []string{"B", "A"},
   256  				"secret":     p.Secret,
   257  				"account_id": "unknown_acc_id",
   258  			})
   259  			So(call(req), ShouldEqual, `HTTP 404: Unrecognized account ID "unknown_acc_id".`)
   260  		})
   261  
   262  		Convey("Token generator returns fatal error", func() {
   263  			responses <- fmt.Errorf("fatal!!111")
   264  			So(call(goodOAuthRequest()), ShouldEqual, `HTTP 200 (json): {"error_code":-1,"error_message":"fatal!!111"}`)
   265  		})
   266  
   267  		Convey("Token generator returns ErrorWithCode", func() {
   268  			responses <- errWithCode{
   269  				error: fmt.Errorf("with code"),
   270  				code:  123,
   271  			}
   272  			So(call(goodOAuthRequest()), ShouldEqual, `HTTP 200 (json): {"error_code":123,"error_message":"with code"}`)
   273  		})
   274  
   275  		Convey("Token generator returns transient error", func() {
   276  			responses <- errors.New("transient", transient.Tag)
   277  			So(call(goodOAuthRequest()), ShouldEqual, `HTTP 500: Transient error - transient`)
   278  		})
   279  	})
   280  }
   281  
   282  type errWithCode struct {
   283  	error
   284  	code int
   285  }
   286  
   287  func (e errWithCode) Code() int {
   288  	return e.code
   289  }
   290  
   291  func prepReq(p *lucictx.LocalAuth, uri string, body any) *http.Request {
   292  	var reader io.Reader
   293  	isJSON := false
   294  	if body != nil {
   295  		blob, ok := body.([]byte)
   296  		if !ok {
   297  			var err error
   298  			blob, err = json.Marshal(body)
   299  			if err != nil {
   300  				panic(err)
   301  			}
   302  			isJSON = true
   303  		}
   304  		reader = bytes.NewReader(blob)
   305  	}
   306  	req, err := http.NewRequest("POST", fmt.Sprintf("http://127.0.0.1:%d%s", p.RpcPort, uri), reader)
   307  	if err != nil {
   308  		panic(err)
   309  	}
   310  	if isJSON {
   311  		req.Header.Set("Content-Type", "application/json")
   312  	}
   313  	return req
   314  }
   315  
   316  func call(req *http.Request) any {
   317  	resp, err := http.DefaultClient.Do(req)
   318  	if err != nil {
   319  		panic(err)
   320  	}
   321  	defer resp.Body.Close()
   322  
   323  	blob, err := io.ReadAll(resp.Body)
   324  	if err != nil {
   325  		panic(err)
   326  	}
   327  
   328  	tp := ""
   329  	if resp.Header.Get("Content-Type") == "application/json; charset=utf-8" {
   330  		tp = " (json)"
   331  	}
   332  
   333  	return fmt.Sprintf("HTTP %d%s: %s", resp.StatusCode, tp, strings.TrimSpace(string(blob)))
   334  }