go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/auth/integration/localauth/server_test.go (about) 1 // Copyright 2017 The LUCI Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package localauth 16 17 import ( 18 "bytes" 19 "context" 20 "encoding/json" 21 "fmt" 22 "io" 23 "net/http" 24 "strings" 25 "testing" 26 "time" 27 28 "golang.org/x/oauth2" 29 30 "go.chromium.org/luci/common/clock" 31 "go.chromium.org/luci/common/clock/testclock" 32 "go.chromium.org/luci/common/errors" 33 "go.chromium.org/luci/common/retry/transient" 34 "go.chromium.org/luci/lucictx" 35 36 . "github.com/smartystreets/goconvey/convey" 37 . "go.chromium.org/luci/common/testing/assertions" 38 ) 39 40 type callbackGen struct { 41 email string 42 cb func(context.Context, []string, time.Duration) (*oauth2.Token, error) 43 } 44 45 func (g *callbackGen) GenerateOAuthToken(ctx context.Context, scopes []string, lifetime time.Duration) (*oauth2.Token, error) { 46 return g.cb(ctx, scopes, lifetime) 47 } 48 49 func (g *callbackGen) GenerateIDToken(ctx context.Context, audience string, lifetime time.Duration) (*oauth2.Token, error) { 50 return g.cb(ctx, []string{"audience:" + audience}, lifetime) 51 } 52 53 func (g *callbackGen) GetEmail() (string, error) { 54 return g.email, nil 55 } 56 57 func makeGenerator(email string, cb func(context.Context, []string, time.Duration) (*oauth2.Token, error)) TokenGenerator { 58 return &callbackGen{email, cb} 59 } 60 61 func TestProtocol(t *testing.T) { 62 t.Parallel() 63 64 ctx := context.Background() 65 ctx, _ = testclock.UseTime(ctx, testclock.TestRecentTimeUTC) 66 67 Convey("With server", t, func(c C) { 68 // Use channels to pass mocked requests/responses back and forth. 69 requests := make(chan []string, 10000) 70 responses := make(chan any, 1) 71 72 testGen := func(ctx context.Context, scopes []string, lifetime time.Duration) (*oauth2.Token, error) { 73 requests <- scopes 74 var resp any 75 select { 76 case resp = <-responses: 77 default: 78 c.Println("Unexpected token request") 79 return nil, fmt.Errorf("Unexpected request") 80 } 81 switch resp := resp.(type) { 82 case error: 83 return nil, resp 84 case *oauth2.Token: 85 return resp, nil 86 default: 87 panic("unknown response") 88 } 89 } 90 91 s := Server{ 92 TokenGenerators: map[string]TokenGenerator{ 93 "acc_id": makeGenerator("some@example.com", testGen), 94 "another_id": makeGenerator("another@example.com", testGen), 95 }, 96 DefaultAccountID: "acc_id", 97 } 98 p, err := s.Start(ctx) 99 So(err, ShouldBeNil) 100 defer s.Stop(ctx) 101 102 So(p.Accounts[0], ShouldResembleProto, &lucictx.LocalAuthAccount{ 103 Id: "acc_id", Email: "some@example.com", 104 }) 105 So(p.Accounts[1], ShouldResembleProto, &lucictx.LocalAuthAccount{ 106 Id: "another_id", Email: "another@example.com", 107 }) 108 So(p.DefaultAccountId, ShouldEqual, "acc_id") 109 110 goodOAuthRequest := func() *http.Request { 111 return prepReq(p, "/rpc/LuciLocalAuthService.GetOAuthToken", map[string]any{ 112 "scopes": []string{"B", "A"}, 113 "secret": p.Secret, 114 "account_id": "acc_id", 115 }) 116 } 117 118 goodIDTokRequest := func() *http.Request { 119 return prepReq(p, "/rpc/LuciLocalAuthService.GetIDToken", map[string]any{ 120 "audience": "A", 121 "secret": p.Secret, 122 "account_id": "acc_id", 123 }) 124 } 125 126 Convey("Access tokens happy path", func() { 127 responses <- &oauth2.Token{ 128 AccessToken: "tok1", 129 Expiry: clock.Now(ctx).Add(30 * time.Minute), 130 } 131 So(call(goodOAuthRequest()), ShouldEqual, `HTTP 200 (json): {"access_token":"tok1","expiry":1454474106}`) 132 So(<-requests, ShouldResemble, []string{"A", "B"}) 133 134 // application/json is also the default. 135 req := goodOAuthRequest() 136 req.Header.Del("Content-Type") 137 responses <- &oauth2.Token{ 138 AccessToken: "tok2", 139 Expiry: clock.Now(ctx).Add(30 * time.Minute), 140 } 141 So(call(req), ShouldEqual, `HTTP 200 (json): {"access_token":"tok2","expiry":1454474106}`) 142 So(<-requests, ShouldResemble, []string{"A", "B"}) 143 }) 144 145 Convey("ID tokens happy path", func() { 146 responses <- &oauth2.Token{ 147 AccessToken: "tok1", 148 Expiry: clock.Now(ctx).Add(30 * time.Minute), 149 } 150 So(call(goodIDTokRequest()), ShouldEqual, `HTTP 200 (json): {"id_token":"tok1","expiry":1454474106}`) 151 So(<-requests, ShouldResemble, []string{"audience:A"}) 152 153 // application/json is also the default. 154 req := goodIDTokRequest() 155 req.Header.Del("Content-Type") 156 responses <- &oauth2.Token{ 157 AccessToken: "tok2", 158 Expiry: clock.Now(ctx).Add(30 * time.Minute), 159 } 160 So(call(req), ShouldEqual, `HTTP 200 (json): {"id_token":"tok2","expiry":1454474106}`) 161 So(<-requests, ShouldResemble, []string{"audience:A"}) 162 }) 163 164 Convey("Panic in token generator", func() { 165 responses <- "omg, panic" 166 So(call(goodOAuthRequest()), ShouldEqual, `HTTP 500: Internal Server Error. See logs.`) 167 }) 168 169 Convey("Not POST", func() { 170 req := goodOAuthRequest() 171 req.Method = "PUT" 172 So(call(req), ShouldEqual, `HTTP 405: Expecting POST`) 173 }) 174 175 Convey("Bad URI", func() { 176 req := goodOAuthRequest() 177 req.URL.Path = "/zzz" 178 So(call(req), ShouldEqual, `HTTP 404: Expecting /rpc/LuciLocalAuthService.<method>`) 179 }) 180 181 Convey("Bad content type", func() { 182 req := goodOAuthRequest() 183 req.Header.Set("Content-Type", "bzzzz") 184 So(call(req), ShouldEqual, `HTTP 400: Expecting 'application/json' Content-Type`) 185 }) 186 187 Convey("Broken json", func() { 188 req := goodOAuthRequest() 189 190 body := `not a json` 191 req.Body = io.NopCloser(bytes.NewBufferString(body)) 192 req.ContentLength = int64(len(body)) 193 194 So(call(req), ShouldEqual, `HTTP 400: Not JSON body - invalid character 'o' in literal null (expecting 'u')`) 195 }) 196 197 Convey("Huge request", func() { 198 req := goodOAuthRequest() 199 200 body := strings.Repeat("z", 64*1024+1) 201 req.Body = io.NopCloser(bytes.NewBufferString(body)) 202 req.ContentLength = int64(len(body)) 203 204 So(call(req), ShouldEqual, `HTTP 400: Expecting 'Content-Length' header, <64Kb`) 205 }) 206 207 Convey("Unknown RPC method", func() { 208 req := prepReq(p, "/rpc/LuciLocalAuthService.UnknownMethod", map[string]any{}) 209 So(call(req), ShouldEqual, `HTTP 404: Unknown RPC method "UnknownMethod"`) 210 }) 211 212 Convey("No scopes", func() { 213 req := prepReq(p, "/rpc/LuciLocalAuthService.GetOAuthToken", map[string]any{ 214 "secret": p.Secret, 215 "account_id": "acc_id", 216 }) 217 So(call(req), ShouldEqual, `HTTP 400: Bad request: field "scopes" is required.`) 218 }) 219 220 Convey("No audience", func() { 221 req := prepReq(p, "/rpc/LuciLocalAuthService.GetIDToken", map[string]any{ 222 "secret": p.Secret, 223 "account_id": "acc_id", 224 }) 225 So(call(req), ShouldEqual, `HTTP 400: Bad request: field "audience" is required.`) 226 }) 227 228 Convey("No secret", func() { 229 req := prepReq(p, "/rpc/LuciLocalAuthService.GetOAuthToken", map[string]any{ 230 "scopes": []string{"B", "A"}, 231 "account_id": "acc_id", 232 }) 233 So(call(req), ShouldEqual, `HTTP 400: Bad request: field "secret" is required.`) 234 }) 235 236 Convey("Bad secret", func() { 237 req := prepReq(p, "/rpc/LuciLocalAuthService.GetOAuthToken", map[string]any{ 238 "scopes": []string{"B", "A"}, 239 "secret": []byte{0, 1, 2, 3}, 240 "account_id": "acc_id", 241 }) 242 So(call(req), ShouldEqual, `HTTP 403: Invalid secret.`) 243 }) 244 245 Convey("No account ID", func() { 246 req := prepReq(p, "/rpc/LuciLocalAuthService.GetOAuthToken", map[string]any{ 247 "scopes": []string{"B", "A"}, 248 "secret": p.Secret, 249 }) 250 So(call(req), ShouldEqual, `HTTP 400: Bad request: field "account_id" is required.`) 251 }) 252 253 Convey("Unknown account ID", func() { 254 req := prepReq(p, "/rpc/LuciLocalAuthService.GetOAuthToken", map[string]any{ 255 "scopes": []string{"B", "A"}, 256 "secret": p.Secret, 257 "account_id": "unknown_acc_id", 258 }) 259 So(call(req), ShouldEqual, `HTTP 404: Unrecognized account ID "unknown_acc_id".`) 260 }) 261 262 Convey("Token generator returns fatal error", func() { 263 responses <- fmt.Errorf("fatal!!111") 264 So(call(goodOAuthRequest()), ShouldEqual, `HTTP 200 (json): {"error_code":-1,"error_message":"fatal!!111"}`) 265 }) 266 267 Convey("Token generator returns ErrorWithCode", func() { 268 responses <- errWithCode{ 269 error: fmt.Errorf("with code"), 270 code: 123, 271 } 272 So(call(goodOAuthRequest()), ShouldEqual, `HTTP 200 (json): {"error_code":123,"error_message":"with code"}`) 273 }) 274 275 Convey("Token generator returns transient error", func() { 276 responses <- errors.New("transient", transient.Tag) 277 So(call(goodOAuthRequest()), ShouldEqual, `HTTP 500: Transient error - transient`) 278 }) 279 }) 280 } 281 282 type errWithCode struct { 283 error 284 code int 285 } 286 287 func (e errWithCode) Code() int { 288 return e.code 289 } 290 291 func prepReq(p *lucictx.LocalAuth, uri string, body any) *http.Request { 292 var reader io.Reader 293 isJSON := false 294 if body != nil { 295 blob, ok := body.([]byte) 296 if !ok { 297 var err error 298 blob, err = json.Marshal(body) 299 if err != nil { 300 panic(err) 301 } 302 isJSON = true 303 } 304 reader = bytes.NewReader(blob) 305 } 306 req, err := http.NewRequest("POST", fmt.Sprintf("http://127.0.0.1:%d%s", p.RpcPort, uri), reader) 307 if err != nil { 308 panic(err) 309 } 310 if isJSON { 311 req.Header.Set("Content-Type", "application/json") 312 } 313 return req 314 } 315 316 func call(req *http.Request) any { 317 resp, err := http.DefaultClient.Do(req) 318 if err != nil { 319 panic(err) 320 } 321 defer resp.Body.Close() 322 323 blob, err := io.ReadAll(resp.Body) 324 if err != nil { 325 panic(err) 326 } 327 328 tp := "" 329 if resp.Header.Get("Content-Type") == "application/json; charset=utf-8" { 330 tp = " (json)" 331 } 332 333 return fmt.Sprintf("HTTP %d%s: %s", resp.StatusCode, tp, strings.TrimSpace(string(blob))) 334 }