github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/server/authentication_test.go (about) 1 // Copyright 2015 The Cockroach Authors. 2 // 3 // Use of this software is governed by the Business Source License 4 // included in the file licenses/BSL.txt. 5 // 6 // As of the Change Date specified in that file, in accordance with 7 // the Business Source License, use of this software will be governed 8 // by the Apache License, Version 2.0, included in the file 9 // licenses/APL.txt. 10 11 package server 12 13 import ( 14 "bytes" 15 "context" 16 "crypto/sha256" 17 "crypto/tls" 18 gosql "database/sql" 19 "fmt" 20 "io/ioutil" 21 "net/http" 22 "net/http/cookiejar" 23 "net/url" 24 "testing" 25 "time" 26 27 "github.com/cockroachdb/cockroach/pkg/base" 28 "github.com/cockroachdb/cockroach/pkg/gossip" 29 "github.com/cockroachdb/cockroach/pkg/kv/kvserver" 30 "github.com/cockroachdb/cockroach/pkg/kv/kvserver/closedts/ctpb" 31 "github.com/cockroachdb/cockroach/pkg/roachpb" 32 "github.com/cockroachdb/cockroach/pkg/security" 33 "github.com/cockroachdb/cockroach/pkg/server/debug" 34 "github.com/cockroachdb/cockroach/pkg/server/serverpb" 35 "github.com/cockroachdb/cockroach/pkg/sql/execinfrapb" 36 "github.com/cockroachdb/cockroach/pkg/testutils" 37 "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" 38 "github.com/cockroachdb/cockroach/pkg/ts" 39 "github.com/cockroachdb/cockroach/pkg/ts/tspb" 40 "github.com/cockroachdb/cockroach/pkg/util" 41 "github.com/cockroachdb/cockroach/pkg/util/httputil" 42 "github.com/cockroachdb/cockroach/pkg/util/leaktest" 43 "github.com/cockroachdb/cockroach/pkg/util/timeutil" 44 "github.com/cockroachdb/errors" 45 "github.com/gogo/protobuf/jsonpb" 46 "github.com/lib/pq" 47 "golang.org/x/crypto/bcrypt" 48 "google.golang.org/grpc" 49 "google.golang.org/grpc/credentials" 50 ) 51 52 type ctxI interface { 53 GetHTTPClient() (http.Client, error) 54 HTTPRequestScheme() string 55 } 56 57 var _ ctxI = insecureCtx{} 58 var _ ctxI = (*base.Config)(nil) 59 60 type insecureCtx struct{} 61 62 func (insecureCtx) GetHTTPClient() (http.Client, error) { 63 return http.Client{ 64 Transport: &http.Transport{ 65 TLSClientConfig: &tls.Config{ 66 InsecureSkipVerify: true, 67 }, 68 }, 69 }, nil 70 } 71 72 func (insecureCtx) HTTPRequestScheme() string { 73 return "https" 74 } 75 76 // Verify client certificate enforcement and user whitelisting. 77 func TestSSLEnforcement(t *testing.T) { 78 defer leaktest.AfterTest(t)() 79 s, _, _ := serverutils.StartServer(t, base.TestServerArgs{ 80 // This test is verifying the (unimplemented) authentication of SSL 81 // client certificates over HTTP endpoints. Web session authentication 82 // is disabled in order to avoid the need to authenticate the individual 83 // clients being instantiated. 84 DisableWebSessionAuthentication: true, 85 }) 86 defer s.Stopper().Stop(context.Background()) 87 88 // HTTPS with client certs for security.RootUser. 89 rootCertsContext := testutils.NewTestBaseContext(security.RootUser) 90 // HTTPS with client certs for security.NodeUser. 91 nodeCertsContext := testutils.NewNodeTestBaseContext() 92 // HTTPS with client certs for TestUser. 93 testCertsContext := testutils.NewTestBaseContext(TestUser) 94 // HTTPS without client certs. The user does not matter. 95 noCertsContext := insecureCtx{} 96 // Plain http. 97 insecureContext := testutils.NewTestBaseContext(TestUser) 98 insecureContext.Insecure = true 99 100 kvGet := &roachpb.GetRequest{} 101 kvGet.Key = roachpb.Key("/") 102 103 for _, tc := range []struct { 104 path string 105 ctx ctxI 106 code int // http response code 107 }{ 108 // Health endpoint is special-cased; allowed to serve on HTTP. 109 {"/health", insecureContext, http.StatusOK}, 110 111 // /ui/: basic file server: no auth. 112 {"", rootCertsContext, http.StatusOK}, 113 {"", nodeCertsContext, http.StatusOK}, 114 {"", testCertsContext, http.StatusOK}, 115 {"", noCertsContext, http.StatusOK}, 116 {"", insecureContext, http.StatusTemporaryRedirect}, 117 118 // /_admin/: server.adminServer: no auth. 119 {adminPrefix + "health", rootCertsContext, http.StatusOK}, 120 {adminPrefix + "health", nodeCertsContext, http.StatusOK}, 121 {adminPrefix + "health", testCertsContext, http.StatusOK}, 122 {adminPrefix + "health", noCertsContext, http.StatusOK}, 123 {adminPrefix + "health", insecureContext, http.StatusTemporaryRedirect}, 124 125 // /debug/: server.adminServer: no auth. 126 {debug.Endpoint + "vars", rootCertsContext, http.StatusOK}, 127 {debug.Endpoint + "vars", nodeCertsContext, http.StatusOK}, 128 {debug.Endpoint + "vars", testCertsContext, http.StatusOK}, 129 {debug.Endpoint + "vars", noCertsContext, http.StatusOK}, 130 {debug.Endpoint + "vars", insecureContext, http.StatusTemporaryRedirect}, 131 132 // /_status/nodes: server.statusServer: no auth. 133 {statusPrefix + "nodes", rootCertsContext, http.StatusOK}, 134 {statusPrefix + "nodes", nodeCertsContext, http.StatusOK}, 135 {statusPrefix + "nodes", testCertsContext, http.StatusOK}, 136 {statusPrefix + "nodes", noCertsContext, http.StatusOK}, 137 {statusPrefix + "nodes", insecureContext, http.StatusTemporaryRedirect}, 138 139 // /ts/: ts.Server: no auth. 140 {ts.URLPrefix, rootCertsContext, http.StatusNotFound}, 141 {ts.URLPrefix, nodeCertsContext, http.StatusNotFound}, 142 {ts.URLPrefix, testCertsContext, http.StatusNotFound}, 143 {ts.URLPrefix, noCertsContext, http.StatusNotFound}, 144 {ts.URLPrefix, insecureContext, http.StatusTemporaryRedirect}, 145 } { 146 t.Run("", func(t *testing.T) { 147 client, err := tc.ctx.GetHTTPClient() 148 if err != nil { 149 t.Fatal(err) 150 } 151 // Avoid automatically following redirects. 152 client.CheckRedirect = func(*http.Request, []*http.Request) error { 153 return http.ErrUseLastResponse 154 } 155 url := url.URL{ 156 Scheme: tc.ctx.HTTPRequestScheme(), 157 Host: s.(*TestServer).Cfg.HTTPAddr, 158 Path: tc.path, 159 } 160 resp, err := client.Get(url.String()) 161 if err != nil { 162 t.Fatal(err) 163 } 164 165 defer resp.Body.Close() 166 if resp.StatusCode != tc.code { 167 t.Errorf("expected status code %d, got %d", tc.code, resp.StatusCode) 168 u, err := resp.Location() 169 t.Errorf("orig=%s url=%s err=%v", tc.path, u, err) 170 } 171 }) 172 } 173 } 174 175 func TestVerifyPassword(t *testing.T) { 176 defer leaktest.AfterTest(t)() 177 178 ctx := context.Background() 179 s, db, _ := serverutils.StartServer(t, base.TestServerArgs{}) 180 defer s.Stopper().Stop(ctx) 181 182 ts := s.(*TestServer) 183 184 if util.RaceEnabled { 185 // The default bcrypt cost makes this test approximately 30s slower when the 186 // race detector is on. 187 defer func(prev int) { security.BcryptCost = prev }(security.BcryptCost) 188 security.BcryptCost = bcrypt.MinCost 189 } 190 191 //location is used for timezone testing. 192 shanghaiLoc, err := time.LoadLocation("Asia/Shanghai") 193 if err != nil { 194 t.Fatal(err) 195 } 196 197 for _, user := range []struct { 198 username string 199 password string 200 loginFlag string 201 validUntilClause string 202 qargs []interface{} 203 }{ 204 {"azure_diamond", "hunter2", "", "", nil}, 205 {"druidia", "12345", "", "", nil}, 206 207 {"richardc", "12345", "NOLOGIN", "", nil}, 208 {"before_epoch", "12345", "", "VALID UNTIL '1969-01-01'", nil}, 209 {"epoch", "12345", "", "VALID UNTIL '1970-01-01'", nil}, 210 {"cockroach", "12345", "", "VALID UNTIL '2100-01-01'", nil}, 211 {"cthon98", "12345", "", "VALID UNTIL NULL", nil}, 212 213 {"toolate", "12345", "", "VALID UNTIL $1", 214 []interface{}{timeutil.Now().Add(-10 * time.Minute)}}, 215 {"timelord", "12345", "", "VALID UNTIL $1", 216 []interface{}{timeutil.Now().Add(59 * time.Minute).In(shanghaiLoc)}}, 217 } { 218 cmd := fmt.Sprintf( 219 "CREATE USER %s WITH PASSWORD '%s' %s %s", 220 user.username, user.password, user.loginFlag, user.validUntilClause) 221 222 if _, err := db.Exec(cmd, user.qargs...); err != nil { 223 t.Fatalf("failed to create user: %s", err) 224 } 225 } 226 227 for _, tc := range []struct { 228 username string 229 password string 230 shouldAuthenticate bool 231 expectedErrString string 232 }{ 233 {"azure_diamond", "hunter2", true, ""}, 234 {"azure_diamond", "hunter", false, "crypto/bcrypt"}, 235 {"azure_diamond", "", false, "crypto/bcrypt"}, 236 {"azure_diamond", "🍦", false, "crypto/bcrypt"}, 237 {"azure_diamond", "hunter2345", false, "crypto/bcrypt"}, 238 {"azure_diamond", "shunter2", false, "crypto/bcrypt"}, 239 {"azure_diamond", "12345", false, "crypto/bcrypt"}, 240 {"azure_diamond", "*******", false, "crypto/bcrypt"}, 241 {"druidia", "12345", true, ""}, 242 {"druidia", "hunter2", false, "crypto/bcrypt"}, 243 {"root", "", false, "crypto/bcrypt"}, 244 {"", "", false, "does not exist"}, 245 {"doesntexist", "zxcvbn", false, "does not exist"}, 246 247 {"richardc", "12345", false, 248 "richardc does not have login privilege"}, 249 {"before_epoch", "12345", false, ""}, 250 {"epoch", "12345", false, ""}, 251 {"cockroach", "12345", true, ""}, 252 {"toolate", "12345", false, ""}, 253 {"timelord", "12345", true, ""}, 254 {"cthon98", "12345", true, ""}, 255 } { 256 t.Run("", func(t *testing.T) { 257 valid, expired, err := ts.authentication.verifyPassword(context.Background(), tc.username, tc.password) 258 if err != nil { 259 t.Errorf( 260 "credentials %s/%s failed with error %s, wanted no error", 261 tc.username, 262 tc.password, 263 err, 264 ) 265 } 266 if valid && !expired != tc.shouldAuthenticate { 267 t.Errorf( 268 "credentials %s/%s valid = %t, wanted %t", 269 tc.username, 270 tc.password, 271 valid, 272 tc.shouldAuthenticate, 273 ) 274 } 275 }) 276 } 277 } 278 279 func TestCreateSession(t *testing.T) { 280 defer leaktest.AfterTest(t)() 281 s, db, _ := serverutils.StartServer(t, base.TestServerArgs{}) 282 defer s.Stopper().Stop(context.Background()) 283 ts := s.(*TestServer) 284 285 username := "testUser" 286 287 // Create an authentication, noting the time before and after creation. This 288 // lets us ensure that the timestamps created are accurate. 289 timeBoundBefore := ts.clock.PhysicalTime() 290 id, origSecret, err := ts.authentication.newAuthSession(context.Background(), username) 291 if err != nil { 292 t.Fatalf("error creating auth session: %s", err) 293 } 294 timeBoundAfter := ts.clock.PhysicalTime() 295 296 // Query fields from created session. 297 query := ` 298 SELECT "hashedSecret", "username", "createdAt", "lastUsedAt", "expiresAt", "revokedAt", "auditInfo" 299 FROM system.web_sessions 300 WHERE id = $1` 301 302 result := db.QueryRow(query, id) 303 var ( 304 sessHashedSecret []byte 305 sessUsername string 306 sessCreated time.Time 307 sessLastUsed time.Time 308 sessExpires time.Time 309 sessRevoked pq.NullTime 310 sessAuditInfo gosql.NullString 311 ) 312 if err := result.Scan( 313 &sessHashedSecret, 314 &sessUsername, 315 &sessCreated, 316 &sessLastUsed, 317 &sessExpires, 318 &sessRevoked, 319 &sessAuditInfo, 320 ); err != nil { 321 t.Fatalf("error querying created auth session: %s", err) 322 } 323 324 // Verify hashed secret matches original secret 325 hasher := sha256.New() 326 _, _ = hasher.Write(origSecret) 327 hashedSecret := hasher.Sum(nil) 328 if !bytes.Equal(sessHashedSecret, hashedSecret) { 329 t.Fatalf("hashed value of secret: \n%#v\ncomputed as: \n%#v\nwanted: \n%#v", origSecret, hashedSecret, sessHashedSecret) 330 } 331 332 // Username. 333 if a, e := sessUsername, username; a != e { 334 t.Fatalf("session username got %s, wanted %s", a, e) 335 } 336 337 // Timestamps. 338 verifyTimestamp := func(actual time.Time, early time.Time, late time.Time) error { 339 if actual.Before(early) { 340 return errors.Errorf("time %s was before early bound %s", actual, early) 341 } 342 if late.Before(actual) { 343 return errors.Errorf("time %s was after late bound %s", actual, late) 344 } 345 return nil 346 } 347 348 if err := verifyTimestamp(sessCreated, timeBoundBefore, timeBoundAfter); err != nil { 349 t.Fatalf("bad createdAt timestamp: %s", err) 350 } 351 if err := verifyTimestamp(sessLastUsed, timeBoundBefore, timeBoundAfter); err != nil { 352 t.Fatalf("bad lastUsedAt timestamp: %s", err) 353 } 354 timeout := webSessionTimeout.Get(&s.ClusterSettings().SV) 355 if err := verifyTimestamp( 356 sessExpires, timeBoundBefore.Add(timeout), timeBoundAfter.Add(timeout), 357 ); err != nil { 358 t.Fatalf("bad expiresAt timestamp: %s", err) 359 } 360 361 // Null fields 362 if sessRevoked.Valid { 363 t.Fatalf("sess had revokedAt timestamp %s, wanted null", sessRevoked.Time) 364 } 365 if sessAuditInfo.Valid { 366 t.Fatalf("sess had auditInfo %s, wanted null", sessAuditInfo.String) 367 } 368 } 369 370 func TestVerifySession(t *testing.T) { 371 defer leaktest.AfterTest(t)() 372 s, _, _ := serverutils.StartServer(t, base.TestServerArgs{}) 373 defer s.Stopper().Stop(context.Background()) 374 ts := s.(*TestServer) 375 376 sessionUsername := "testUser" 377 id, origSecret, err := ts.authentication.newAuthSession(context.Background(), sessionUsername) 378 if err != nil { 379 t.Fatal(err) 380 } 381 382 for _, tc := range []struct { 383 testname string 384 cookie serverpb.SessionCookie 385 shouldVerify bool 386 }{ 387 { 388 testname: "Valid cookie", 389 cookie: serverpb.SessionCookie{ 390 ID: id, 391 Secret: origSecret, 392 }, 393 shouldVerify: true, 394 }, 395 { 396 testname: "No secret", 397 cookie: serverpb.SessionCookie{ 398 ID: id, 399 }, 400 shouldVerify: false, 401 }, 402 { 403 testname: "Wrong secret", 404 cookie: serverpb.SessionCookie{ 405 ID: id, 406 Secret: []byte{0x01, 0x02, 0x03, 0x04}, 407 }, 408 shouldVerify: false, 409 }, 410 { 411 testname: "No ID", 412 cookie: serverpb.SessionCookie{ 413 Secret: origSecret, 414 }, 415 shouldVerify: false, 416 }, 417 { 418 testname: "Wrong ID", 419 cookie: serverpb.SessionCookie{ 420 ID: 123456, 421 Secret: origSecret, 422 }, 423 shouldVerify: false, 424 }, 425 { 426 testname: "Empty cookie", 427 cookie: serverpb.SessionCookie{}, 428 shouldVerify: false, 429 }, 430 } { 431 t.Run(tc.testname, func(t *testing.T) { 432 valid, username, err := ts.authentication.verifySession(context.Background(), &tc.cookie) 433 if err != nil { 434 t.Fatalf("test got error %s, wanted no error", err) 435 } 436 if a, e := valid, tc.shouldVerify; a != e { 437 t.Fatalf("cookie %v verification = %t, wanted %t", tc.cookie, a, e) 438 } 439 if a, e := username, sessionUsername; tc.shouldVerify && a != e { 440 t.Fatalf("cookie %v verification returned username %s, wanted %s", tc.cookie, a, e) 441 } 442 }) 443 } 444 } 445 446 func TestAuthenticationAPIUserLogin(t *testing.T) { 447 defer leaktest.AfterTest(t)() 448 s, db, _ := serverutils.StartServer(t, base.TestServerArgs{}) 449 defer s.Stopper().Stop(context.Background()) 450 ts := s.(*TestServer) 451 452 const ( 453 validUsername = "testuser" 454 validPassword = "password" 455 ) 456 457 cmd := fmt.Sprintf("CREATE USER %s WITH PASSWORD '%s'", validUsername, validPassword) 458 if _, err := db.Exec(cmd); err != nil { 459 t.Fatalf("failed to create user: %s", err) 460 } 461 462 tryLogin := func(username, password string) (*http.Response, error) { 463 // We need to instantiate our own HTTP Request, because we must inspect 464 // the returned headers. 465 httpClient, err := ts.GetHTTPClient() 466 if util.RaceEnabled { 467 httpClient.Timeout += 30 * time.Second 468 } 469 if err != nil { 470 t.Fatalf("could not get HTTP client: %s", err) 471 } 472 req := serverpb.UserLoginRequest{ 473 Username: username, 474 Password: password, 475 } 476 var resp serverpb.UserLoginResponse 477 return httputil.PostJSONWithRequest( 478 httpClient, ts.AdminURL()+loginPath, &req, &resp, 479 ) 480 } 481 482 // Unsuccessful attempt. Should come back with a 401 and no "Set-Cookie" 483 { 484 response, err := tryLogin(validUsername, "wrongpassword") 485 if !testutils.IsError(err, "status: 401") { 486 t.Fatalf("login got error %s, wanted error with 401 status", err) 487 } 488 if cookies := response.Cookies(); len(cookies) > 0 { 489 t.Fatalf("bad login got cookies %v, wanted empty", cookies) 490 } 491 } 492 493 // Successful attempt. Should succeed and return a Set-Cookie header. 494 response, err := tryLogin(validUsername, validPassword) 495 if err != nil { 496 t.Fatalf("good login got error %s, wanted no error", err) 497 } 498 cookies := response.Cookies() 499 if len(cookies) == 0 { 500 t.Fatalf("good login got no cookies: %v", response) 501 } 502 503 sessionCookie, err := decodeSessionCookie(cookies[0]) 504 if err != nil { 505 t.Fatalf("failed to decode session cookie: %s", err) 506 } 507 508 // Look up session in database and verify hashed secret value and username. 509 query := `SELECT "hashedSecret", "username" FROM system.web_sessions WHERE id = $1` 510 result := db.QueryRow(query, sessionCookie.ID) 511 var ( 512 sessHashedSecret []byte 513 sessUsername string 514 ) 515 if err := result.Scan(&sessHashedSecret, &sessUsername); err != nil { 516 t.Fatalf("error querying auth session: %s", err) 517 } 518 519 if a, e := sessUsername, validUsername; a != e { 520 t.Fatalf("created auth session had username %s, wanted %s", a, e) 521 } 522 523 hasher := sha256.New() 524 _, _ = hasher.Write(sessionCookie.Secret) 525 hashedSecret := hasher.Sum(nil) 526 if a, e := sessHashedSecret, hashedSecret; !bytes.Equal(a, e) { 527 t.Fatalf( 528 "session secret hash was %v, wanted %v (derived from original secret %v)", 529 a, 530 e, 531 sessionCookie.Secret, 532 ) 533 } 534 } 535 536 func TestLogout(t *testing.T) { 537 defer leaktest.AfterTest(t)() 538 s, db, _ := serverutils.StartServer(t, base.TestServerArgs{}) 539 defer s.Stopper().Stop(context.Background()) 540 ts := s.(*TestServer) 541 542 // Log in. 543 authHTTPClient, cookie, err := ts.getAuthenticatedHTTPClientAndCookie(authenticatedUserName, true) 544 if err != nil { 545 t.Fatal("error opening HTTP client", err) 546 } 547 548 // Log out. 549 var resp serverpb.UserLogoutResponse 550 if err := httputil.GetJSON(authHTTPClient, ts.AdminURL()+logoutPath, &resp); err != nil { 551 t.Fatal("logout request failed:", err) 552 } 553 554 // Verify that revokedAt has been set in the DB. 555 query := `SELECT "revokedAt" FROM system.web_sessions WHERE id = $1` 556 result := db.QueryRow(query, cookie.ID) 557 var revokedAt string 558 if err := result.Scan(&revokedAt); err != nil { 559 t.Fatalf("error querying auth session: %s", err) 560 } 561 562 if revokedAt == "" { 563 t.Fatal("expected revoked at to not be empty; was empty") 564 } 565 566 databasesURL := ts.AdminURL() + "/_admin/v1/databases" 567 568 // Verify that we're unauthorized after logout. 569 response, err := authHTTPClient.Get(databasesURL) 570 if err != nil { 571 t.Fatal(err) 572 } 573 defer response.Body.Close() 574 575 if response.StatusCode != http.StatusUnauthorized { 576 t.Fatal("expected unauthorized response after logout; got", response.StatusCode) 577 } 578 579 // Try to use the revoked cookie; verify that it doesn't work. 580 parsedURL, err := url.Parse(s.AdminURL()) 581 if err != nil { 582 t.Fatal(err) 583 } 584 encodedCookie, err := EncodeSessionCookie(cookie, false /* forHTTPSOnly */) 585 if err != nil { 586 t.Fatal(err) 587 } 588 589 invalidAuthClient, err := s.GetHTTPClient() 590 if err != nil { 591 t.Fatal(err) 592 } 593 jar, err := cookiejar.New(nil) 594 if err != nil { 595 t.Fatal(err) 596 } 597 invalidAuthClient.Jar = jar 598 invalidAuthClient.Jar.SetCookies(parsedURL, []*http.Cookie{encodedCookie}) 599 600 invalidAuthResp, err := invalidAuthClient.Get(databasesURL) 601 if err != nil { 602 t.Fatal(err) 603 } 604 defer invalidAuthResp.Body.Close() 605 606 if invalidAuthResp.StatusCode != 401 { 607 t.Fatal("expected unauthorized error; got", invalidAuthResp.StatusCode) 608 } 609 } 610 611 // TestAuthenticationMux verifies that the authentication handler is used by all 612 // of the APIs it should be protecting. Authentication is enabled by default for 613 // the test server, and every test which accesses APIs uses an authenticated 614 // client (except for a few that specifically override it). Therefore, this 615 // test verifies that authentication mux is attached to services at all by 616 // testing an endpoint of each with a verified and unverified client. 617 func TestAuthenticationMux(t *testing.T) { 618 defer leaktest.AfterTest(t)() 619 s, _, _ := serverutils.StartServer(t, base.TestServerArgs{}) 620 defer s.Stopper().Stop(context.Background()) 621 tsrv := s.(*TestServer) 622 623 // Both the normal and authenticated client will be used for each test. 624 normalClient, err := tsrv.GetHTTPClient() 625 if err != nil { 626 t.Fatal(err) 627 } 628 authClient, err := tsrv.GetAdminAuthenticatedHTTPClient() 629 if err != nil { 630 t.Fatal(err) 631 } 632 633 runRequest := func( 634 client http.Client, method string, path string, body []byte, expected int, 635 ) error { 636 req, err := http.NewRequest(method, tsrv.AdminURL()+path, bytes.NewBuffer(body)) 637 if err != nil { 638 return err 639 } 640 resp, err := client.Do(req) 641 if err != nil { 642 return err 643 } 644 defer resp.Body.Close() 645 if a, e := resp.StatusCode, expected; a != e { 646 message, err := ioutil.ReadAll(resp.Body) 647 if err != nil { 648 message = []byte(err.Error()) 649 } 650 return errors.Errorf("got status code %d (msg %s), wanted %d", a, string(message), e) 651 } 652 return nil 653 } 654 655 // Generate request for time series API. 656 tsReq := tspb.TimeSeriesQueryRequest{ 657 StartNanos: 0, 658 EndNanos: 100 * 1e9, 659 Queries: []tspb.Query{{Name: "test.metric"}}, 660 } 661 var tsReqBuffer bytes.Buffer 662 marshalFn := (&jsonpb.Marshaler{}).Marshal 663 if err := marshalFn(&tsReqBuffer, &tsReq); err != nil { 664 t.Fatal(err) 665 } 666 667 for _, tc := range []struct { 668 method string 669 path string 670 body []byte 671 }{ 672 {"GET", adminPrefix + "users", nil}, 673 {"GET", statusPrefix + "sessions", nil}, 674 {"POST", ts.URLPrefix + "query", tsReqBuffer.Bytes()}, 675 } { 676 t.Run("path="+tc.path, func(t *testing.T) { 677 // Verify normal client returns 401 Unauthorized. 678 if err := runRequest(normalClient, tc.method, tc.path, tc.body, http.StatusUnauthorized); err != nil { 679 t.Fatalf("request %s failed when not authorized: %s", tc.path, err) 680 } 681 682 // Verify authenticated client returns 200 OK. 683 if err := runRequest(authClient, tc.method, tc.path, tc.body, http.StatusOK); err != nil { 684 t.Fatalf("request %s failed when authorized: %s", tc.path, err) 685 } 686 }) 687 } 688 } 689 690 func TestGRPCAuthentication(t *testing.T) { 691 defer leaktest.AfterTest(t)() 692 693 ctx := context.Background() 694 s, _, _ := serverutils.StartServer(t, base.TestServerArgs{}) 695 defer s.Stopper().Stop(ctx) 696 697 // For each subsystem we pick a representative RPC. The idea is not to 698 // exhaustively test each RPC but to prevent server startup from being 699 // refactored in such a way that an entire subsystem becomes inadvertently 700 // exempt from authentication checks. 701 subsystems := []struct { 702 name string 703 sendRPC func(context.Context, *grpc.ClientConn) error 704 }{ 705 {"gossip", func(ctx context.Context, conn *grpc.ClientConn) error { 706 stream, err := gossip.NewGossipClient(conn).Gossip(ctx) 707 if err != nil { 708 return err 709 } 710 _ = stream.Send(&gossip.Request{}) 711 _, err = stream.Recv() 712 return err 713 }}, 714 {"internal", func(ctx context.Context, conn *grpc.ClientConn) error { 715 _, err := roachpb.NewInternalClient(conn).Batch(ctx, &roachpb.BatchRequest{}) 716 return err 717 }}, 718 {"perReplica", func(ctx context.Context, conn *grpc.ClientConn) error { 719 _, err := kvserver.NewPerReplicaClient(conn).CollectChecksum(ctx, &kvserver.CollectChecksumRequest{}) 720 return err 721 }}, 722 {"raft", func(ctx context.Context, conn *grpc.ClientConn) error { 723 stream, err := kvserver.NewMultiRaftClient(conn).RaftMessageBatch(ctx) 724 if err != nil { 725 return err 726 } 727 _ = stream.Send(&kvserver.RaftMessageRequestBatch{}) 728 _, err = stream.Recv() 729 return err 730 }}, 731 {"closedTimestamp", func(ctx context.Context, conn *grpc.ClientConn) error { 732 stream, err := ctpb.NewClosedTimestampClient(conn).Get(ctx) 733 if err != nil { 734 return err 735 } 736 _ = stream.Send(&ctpb.Reaction{}) 737 _, err = stream.Recv() 738 return err 739 }}, 740 {"distSQL", func(ctx context.Context, conn *grpc.ClientConn) error { 741 stream, err := execinfrapb.NewDistSQLClient(conn).RunSyncFlow(ctx) 742 if err != nil { 743 return err 744 } 745 _ = stream.Send(&execinfrapb.ConsumerSignal{}) 746 _, err = stream.Recv() 747 return err 748 }}, 749 {"init", func(ctx context.Context, conn *grpc.ClientConn) error { 750 _, err := serverpb.NewInitClient(conn).Bootstrap(ctx, &serverpb.BootstrapRequest{}) 751 return err 752 }}, 753 {"admin", func(ctx context.Context, conn *grpc.ClientConn) error { 754 _, err := serverpb.NewAdminClient(conn).Databases(ctx, &serverpb.DatabasesRequest{}) 755 return err 756 }}, 757 {"status", func(ctx context.Context, conn *grpc.ClientConn) error { 758 _, err := serverpb.NewStatusClient(conn).ListSessions(ctx, &serverpb.ListSessionsRequest{}) 759 return err 760 }}, 761 } 762 763 conn, err := grpc.DialContext(ctx, s.ServingRPCAddr(), 764 grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{ 765 InsecureSkipVerify: true, 766 }))) 767 if err != nil { 768 t.Fatal(err) 769 } 770 defer func(conn *grpc.ClientConn) { _ = conn.Close() }(conn) 771 for _, subsystem := range subsystems { 772 t.Run(fmt.Sprintf("no-cert/%s", subsystem.name), func(t *testing.T) { 773 err := subsystem.sendRPC(ctx, conn) 774 if exp := "no client certificates in request"; !testutils.IsError(err, exp) { 775 t.Errorf("expected %q error, but got %v", exp, err) 776 } 777 }) 778 } 779 780 certManager, err := s.RPCContext().GetCertificateManager() 781 if err != nil { 782 t.Fatal(err) 783 } 784 tlsConfig, err := certManager.GetClientTLSConfig("testuser") 785 if err != nil { 786 t.Fatal(err) 787 } 788 conn, err = grpc.DialContext(ctx, s.ServingRPCAddr(), 789 grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))) 790 if err != nil { 791 t.Fatal(err) 792 } 793 defer func(conn *grpc.ClientConn) { _ = conn.Close() }(conn) 794 for _, subsystem := range subsystems { 795 t.Run(fmt.Sprintf("bad-user/%s", subsystem.name), func(t *testing.T) { 796 err := subsystem.sendRPC(ctx, conn) 797 if exp := `user \[testuser\] is not allowed to perform this RPC`; !testutils.IsError(err, exp) { 798 t.Errorf("expected %q error, but got %v", exp, err) 799 } 800 }) 801 } 802 }