github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/server/authentication_test.go (about)

     1  // Copyright 2015 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package server
    12  
    13  import (
    14  	"bytes"
    15  	"context"
    16  	"crypto/sha256"
    17  	"crypto/tls"
    18  	gosql "database/sql"
    19  	"fmt"
    20  	"io/ioutil"
    21  	"net/http"
    22  	"net/http/cookiejar"
    23  	"net/url"
    24  	"testing"
    25  	"time"
    26  
    27  	"github.com/cockroachdb/cockroach/pkg/base"
    28  	"github.com/cockroachdb/cockroach/pkg/gossip"
    29  	"github.com/cockroachdb/cockroach/pkg/kv/kvserver"
    30  	"github.com/cockroachdb/cockroach/pkg/kv/kvserver/closedts/ctpb"
    31  	"github.com/cockroachdb/cockroach/pkg/roachpb"
    32  	"github.com/cockroachdb/cockroach/pkg/security"
    33  	"github.com/cockroachdb/cockroach/pkg/server/debug"
    34  	"github.com/cockroachdb/cockroach/pkg/server/serverpb"
    35  	"github.com/cockroachdb/cockroach/pkg/sql/execinfrapb"
    36  	"github.com/cockroachdb/cockroach/pkg/testutils"
    37  	"github.com/cockroachdb/cockroach/pkg/testutils/serverutils"
    38  	"github.com/cockroachdb/cockroach/pkg/ts"
    39  	"github.com/cockroachdb/cockroach/pkg/ts/tspb"
    40  	"github.com/cockroachdb/cockroach/pkg/util"
    41  	"github.com/cockroachdb/cockroach/pkg/util/httputil"
    42  	"github.com/cockroachdb/cockroach/pkg/util/leaktest"
    43  	"github.com/cockroachdb/cockroach/pkg/util/timeutil"
    44  	"github.com/cockroachdb/errors"
    45  	"github.com/gogo/protobuf/jsonpb"
    46  	"github.com/lib/pq"
    47  	"golang.org/x/crypto/bcrypt"
    48  	"google.golang.org/grpc"
    49  	"google.golang.org/grpc/credentials"
    50  )
    51  
    52  type ctxI interface {
    53  	GetHTTPClient() (http.Client, error)
    54  	HTTPRequestScheme() string
    55  }
    56  
    57  var _ ctxI = insecureCtx{}
    58  var _ ctxI = (*base.Config)(nil)
    59  
    60  type insecureCtx struct{}
    61  
    62  func (insecureCtx) GetHTTPClient() (http.Client, error) {
    63  	return http.Client{
    64  		Transport: &http.Transport{
    65  			TLSClientConfig: &tls.Config{
    66  				InsecureSkipVerify: true,
    67  			},
    68  		},
    69  	}, nil
    70  }
    71  
    72  func (insecureCtx) HTTPRequestScheme() string {
    73  	return "https"
    74  }
    75  
    76  // Verify client certificate enforcement and user whitelisting.
    77  func TestSSLEnforcement(t *testing.T) {
    78  	defer leaktest.AfterTest(t)()
    79  	s, _, _ := serverutils.StartServer(t, base.TestServerArgs{
    80  		// This test is verifying the (unimplemented) authentication of SSL
    81  		// client certificates over HTTP endpoints. Web session authentication
    82  		// is disabled in order to avoid the need to authenticate the individual
    83  		// clients being instantiated.
    84  		DisableWebSessionAuthentication: true,
    85  	})
    86  	defer s.Stopper().Stop(context.Background())
    87  
    88  	// HTTPS with client certs for security.RootUser.
    89  	rootCertsContext := testutils.NewTestBaseContext(security.RootUser)
    90  	// HTTPS with client certs for security.NodeUser.
    91  	nodeCertsContext := testutils.NewNodeTestBaseContext()
    92  	// HTTPS with client certs for TestUser.
    93  	testCertsContext := testutils.NewTestBaseContext(TestUser)
    94  	// HTTPS without client certs. The user does not matter.
    95  	noCertsContext := insecureCtx{}
    96  	// Plain http.
    97  	insecureContext := testutils.NewTestBaseContext(TestUser)
    98  	insecureContext.Insecure = true
    99  
   100  	kvGet := &roachpb.GetRequest{}
   101  	kvGet.Key = roachpb.Key("/")
   102  
   103  	for _, tc := range []struct {
   104  		path string
   105  		ctx  ctxI
   106  		code int // http response code
   107  	}{
   108  		// Health endpoint is special-cased; allowed to serve on HTTP.
   109  		{"/health", insecureContext, http.StatusOK},
   110  
   111  		// /ui/: basic file server: no auth.
   112  		{"", rootCertsContext, http.StatusOK},
   113  		{"", nodeCertsContext, http.StatusOK},
   114  		{"", testCertsContext, http.StatusOK},
   115  		{"", noCertsContext, http.StatusOK},
   116  		{"", insecureContext, http.StatusTemporaryRedirect},
   117  
   118  		// /_admin/: server.adminServer: no auth.
   119  		{adminPrefix + "health", rootCertsContext, http.StatusOK},
   120  		{adminPrefix + "health", nodeCertsContext, http.StatusOK},
   121  		{adminPrefix + "health", testCertsContext, http.StatusOK},
   122  		{adminPrefix + "health", noCertsContext, http.StatusOK},
   123  		{adminPrefix + "health", insecureContext, http.StatusTemporaryRedirect},
   124  
   125  		// /debug/: server.adminServer: no auth.
   126  		{debug.Endpoint + "vars", rootCertsContext, http.StatusOK},
   127  		{debug.Endpoint + "vars", nodeCertsContext, http.StatusOK},
   128  		{debug.Endpoint + "vars", testCertsContext, http.StatusOK},
   129  		{debug.Endpoint + "vars", noCertsContext, http.StatusOK},
   130  		{debug.Endpoint + "vars", insecureContext, http.StatusTemporaryRedirect},
   131  
   132  		// /_status/nodes: server.statusServer: no auth.
   133  		{statusPrefix + "nodes", rootCertsContext, http.StatusOK},
   134  		{statusPrefix + "nodes", nodeCertsContext, http.StatusOK},
   135  		{statusPrefix + "nodes", testCertsContext, http.StatusOK},
   136  		{statusPrefix + "nodes", noCertsContext, http.StatusOK},
   137  		{statusPrefix + "nodes", insecureContext, http.StatusTemporaryRedirect},
   138  
   139  		// /ts/: ts.Server: no auth.
   140  		{ts.URLPrefix, rootCertsContext, http.StatusNotFound},
   141  		{ts.URLPrefix, nodeCertsContext, http.StatusNotFound},
   142  		{ts.URLPrefix, testCertsContext, http.StatusNotFound},
   143  		{ts.URLPrefix, noCertsContext, http.StatusNotFound},
   144  		{ts.URLPrefix, insecureContext, http.StatusTemporaryRedirect},
   145  	} {
   146  		t.Run("", func(t *testing.T) {
   147  			client, err := tc.ctx.GetHTTPClient()
   148  			if err != nil {
   149  				t.Fatal(err)
   150  			}
   151  			// Avoid automatically following redirects.
   152  			client.CheckRedirect = func(*http.Request, []*http.Request) error {
   153  				return http.ErrUseLastResponse
   154  			}
   155  			url := url.URL{
   156  				Scheme: tc.ctx.HTTPRequestScheme(),
   157  				Host:   s.(*TestServer).Cfg.HTTPAddr,
   158  				Path:   tc.path,
   159  			}
   160  			resp, err := client.Get(url.String())
   161  			if err != nil {
   162  				t.Fatal(err)
   163  			}
   164  
   165  			defer resp.Body.Close()
   166  			if resp.StatusCode != tc.code {
   167  				t.Errorf("expected status code %d, got %d", tc.code, resp.StatusCode)
   168  				u, err := resp.Location()
   169  				t.Errorf("orig=%s url=%s err=%v", tc.path, u, err)
   170  			}
   171  		})
   172  	}
   173  }
   174  
   175  func TestVerifyPassword(t *testing.T) {
   176  	defer leaktest.AfterTest(t)()
   177  
   178  	ctx := context.Background()
   179  	s, db, _ := serverutils.StartServer(t, base.TestServerArgs{})
   180  	defer s.Stopper().Stop(ctx)
   181  
   182  	ts := s.(*TestServer)
   183  
   184  	if util.RaceEnabled {
   185  		// The default bcrypt cost makes this test approximately 30s slower when the
   186  		// race detector is on.
   187  		defer func(prev int) { security.BcryptCost = prev }(security.BcryptCost)
   188  		security.BcryptCost = bcrypt.MinCost
   189  	}
   190  
   191  	//location is used for timezone testing.
   192  	shanghaiLoc, err := time.LoadLocation("Asia/Shanghai")
   193  	if err != nil {
   194  		t.Fatal(err)
   195  	}
   196  
   197  	for _, user := range []struct {
   198  		username         string
   199  		password         string
   200  		loginFlag        string
   201  		validUntilClause string
   202  		qargs            []interface{}
   203  	}{
   204  		{"azure_diamond", "hunter2", "", "", nil},
   205  		{"druidia", "12345", "", "", nil},
   206  
   207  		{"richardc", "12345", "NOLOGIN", "", nil},
   208  		{"before_epoch", "12345", "", "VALID UNTIL '1969-01-01'", nil},
   209  		{"epoch", "12345", "", "VALID UNTIL '1970-01-01'", nil},
   210  		{"cockroach", "12345", "", "VALID UNTIL '2100-01-01'", nil},
   211  		{"cthon98", "12345", "", "VALID UNTIL NULL", nil},
   212  
   213  		{"toolate", "12345", "", "VALID UNTIL $1",
   214  			[]interface{}{timeutil.Now().Add(-10 * time.Minute)}},
   215  		{"timelord", "12345", "", "VALID UNTIL $1",
   216  			[]interface{}{timeutil.Now().Add(59 * time.Minute).In(shanghaiLoc)}},
   217  	} {
   218  		cmd := fmt.Sprintf(
   219  			"CREATE USER %s WITH PASSWORD '%s' %s %s",
   220  			user.username, user.password, user.loginFlag, user.validUntilClause)
   221  
   222  		if _, err := db.Exec(cmd, user.qargs...); err != nil {
   223  			t.Fatalf("failed to create user: %s", err)
   224  		}
   225  	}
   226  
   227  	for _, tc := range []struct {
   228  		username           string
   229  		password           string
   230  		shouldAuthenticate bool
   231  		expectedErrString  string
   232  	}{
   233  		{"azure_diamond", "hunter2", true, ""},
   234  		{"azure_diamond", "hunter", false, "crypto/bcrypt"},
   235  		{"azure_diamond", "", false, "crypto/bcrypt"},
   236  		{"azure_diamond", "🍦", false, "crypto/bcrypt"},
   237  		{"azure_diamond", "hunter2345", false, "crypto/bcrypt"},
   238  		{"azure_diamond", "shunter2", false, "crypto/bcrypt"},
   239  		{"azure_diamond", "12345", false, "crypto/bcrypt"},
   240  		{"azure_diamond", "*******", false, "crypto/bcrypt"},
   241  		{"druidia", "12345", true, ""},
   242  		{"druidia", "hunter2", false, "crypto/bcrypt"},
   243  		{"root", "", false, "crypto/bcrypt"},
   244  		{"", "", false, "does not exist"},
   245  		{"doesntexist", "zxcvbn", false, "does not exist"},
   246  
   247  		{"richardc", "12345", false,
   248  			"richardc does not have login privilege"},
   249  		{"before_epoch", "12345", false, ""},
   250  		{"epoch", "12345", false, ""},
   251  		{"cockroach", "12345", true, ""},
   252  		{"toolate", "12345", false, ""},
   253  		{"timelord", "12345", true, ""},
   254  		{"cthon98", "12345", true, ""},
   255  	} {
   256  		t.Run("", func(t *testing.T) {
   257  			valid, expired, err := ts.authentication.verifyPassword(context.Background(), tc.username, tc.password)
   258  			if err != nil {
   259  				t.Errorf(
   260  					"credentials %s/%s failed with error %s, wanted no error",
   261  					tc.username,
   262  					tc.password,
   263  					err,
   264  				)
   265  			}
   266  			if valid && !expired != tc.shouldAuthenticate {
   267  				t.Errorf(
   268  					"credentials %s/%s valid = %t, wanted %t",
   269  					tc.username,
   270  					tc.password,
   271  					valid,
   272  					tc.shouldAuthenticate,
   273  				)
   274  			}
   275  		})
   276  	}
   277  }
   278  
   279  func TestCreateSession(t *testing.T) {
   280  	defer leaktest.AfterTest(t)()
   281  	s, db, _ := serverutils.StartServer(t, base.TestServerArgs{})
   282  	defer s.Stopper().Stop(context.Background())
   283  	ts := s.(*TestServer)
   284  
   285  	username := "testUser"
   286  
   287  	// Create an authentication, noting the time before and after creation. This
   288  	// lets us ensure that the timestamps created are accurate.
   289  	timeBoundBefore := ts.clock.PhysicalTime()
   290  	id, origSecret, err := ts.authentication.newAuthSession(context.Background(), username)
   291  	if err != nil {
   292  		t.Fatalf("error creating auth session: %s", err)
   293  	}
   294  	timeBoundAfter := ts.clock.PhysicalTime()
   295  
   296  	// Query fields from created session.
   297  	query := `
   298  SELECT "hashedSecret", "username", "createdAt", "lastUsedAt", "expiresAt", "revokedAt", "auditInfo"
   299  FROM system.web_sessions
   300  WHERE id = $1`
   301  
   302  	result := db.QueryRow(query, id)
   303  	var (
   304  		sessHashedSecret []byte
   305  		sessUsername     string
   306  		sessCreated      time.Time
   307  		sessLastUsed     time.Time
   308  		sessExpires      time.Time
   309  		sessRevoked      pq.NullTime
   310  		sessAuditInfo    gosql.NullString
   311  	)
   312  	if err := result.Scan(
   313  		&sessHashedSecret,
   314  		&sessUsername,
   315  		&sessCreated,
   316  		&sessLastUsed,
   317  		&sessExpires,
   318  		&sessRevoked,
   319  		&sessAuditInfo,
   320  	); err != nil {
   321  		t.Fatalf("error querying created auth session: %s", err)
   322  	}
   323  
   324  	// Verify hashed secret matches original secret
   325  	hasher := sha256.New()
   326  	_, _ = hasher.Write(origSecret)
   327  	hashedSecret := hasher.Sum(nil)
   328  	if !bytes.Equal(sessHashedSecret, hashedSecret) {
   329  		t.Fatalf("hashed value of secret: \n%#v\ncomputed as: \n%#v\nwanted: \n%#v", origSecret, hashedSecret, sessHashedSecret)
   330  	}
   331  
   332  	// Username.
   333  	if a, e := sessUsername, username; a != e {
   334  		t.Fatalf("session username got %s, wanted %s", a, e)
   335  	}
   336  
   337  	// Timestamps.
   338  	verifyTimestamp := func(actual time.Time, early time.Time, late time.Time) error {
   339  		if actual.Before(early) {
   340  			return errors.Errorf("time %s was before early bound %s", actual, early)
   341  		}
   342  		if late.Before(actual) {
   343  			return errors.Errorf("time %s was after late bound %s", actual, late)
   344  		}
   345  		return nil
   346  	}
   347  
   348  	if err := verifyTimestamp(sessCreated, timeBoundBefore, timeBoundAfter); err != nil {
   349  		t.Fatalf("bad createdAt timestamp: %s", err)
   350  	}
   351  	if err := verifyTimestamp(sessLastUsed, timeBoundBefore, timeBoundAfter); err != nil {
   352  		t.Fatalf("bad lastUsedAt timestamp: %s", err)
   353  	}
   354  	timeout := webSessionTimeout.Get(&s.ClusterSettings().SV)
   355  	if err := verifyTimestamp(
   356  		sessExpires, timeBoundBefore.Add(timeout), timeBoundAfter.Add(timeout),
   357  	); err != nil {
   358  		t.Fatalf("bad expiresAt timestamp: %s", err)
   359  	}
   360  
   361  	// Null fields
   362  	if sessRevoked.Valid {
   363  		t.Fatalf("sess had revokedAt timestamp %s, wanted null", sessRevoked.Time)
   364  	}
   365  	if sessAuditInfo.Valid {
   366  		t.Fatalf("sess had auditInfo %s, wanted null", sessAuditInfo.String)
   367  	}
   368  }
   369  
   370  func TestVerifySession(t *testing.T) {
   371  	defer leaktest.AfterTest(t)()
   372  	s, _, _ := serverutils.StartServer(t, base.TestServerArgs{})
   373  	defer s.Stopper().Stop(context.Background())
   374  	ts := s.(*TestServer)
   375  
   376  	sessionUsername := "testUser"
   377  	id, origSecret, err := ts.authentication.newAuthSession(context.Background(), sessionUsername)
   378  	if err != nil {
   379  		t.Fatal(err)
   380  	}
   381  
   382  	for _, tc := range []struct {
   383  		testname     string
   384  		cookie       serverpb.SessionCookie
   385  		shouldVerify bool
   386  	}{
   387  		{
   388  			testname: "Valid cookie",
   389  			cookie: serverpb.SessionCookie{
   390  				ID:     id,
   391  				Secret: origSecret,
   392  			},
   393  			shouldVerify: true,
   394  		},
   395  		{
   396  			testname: "No secret",
   397  			cookie: serverpb.SessionCookie{
   398  				ID: id,
   399  			},
   400  			shouldVerify: false,
   401  		},
   402  		{
   403  			testname: "Wrong secret",
   404  			cookie: serverpb.SessionCookie{
   405  				ID:     id,
   406  				Secret: []byte{0x01, 0x02, 0x03, 0x04},
   407  			},
   408  			shouldVerify: false,
   409  		},
   410  		{
   411  			testname: "No ID",
   412  			cookie: serverpb.SessionCookie{
   413  				Secret: origSecret,
   414  			},
   415  			shouldVerify: false,
   416  		},
   417  		{
   418  			testname: "Wrong ID",
   419  			cookie: serverpb.SessionCookie{
   420  				ID:     123456,
   421  				Secret: origSecret,
   422  			},
   423  			shouldVerify: false,
   424  		},
   425  		{
   426  			testname:     "Empty cookie",
   427  			cookie:       serverpb.SessionCookie{},
   428  			shouldVerify: false,
   429  		},
   430  	} {
   431  		t.Run(tc.testname, func(t *testing.T) {
   432  			valid, username, err := ts.authentication.verifySession(context.Background(), &tc.cookie)
   433  			if err != nil {
   434  				t.Fatalf("test got error %s, wanted no error", err)
   435  			}
   436  			if a, e := valid, tc.shouldVerify; a != e {
   437  				t.Fatalf("cookie %v verification = %t, wanted %t", tc.cookie, a, e)
   438  			}
   439  			if a, e := username, sessionUsername; tc.shouldVerify && a != e {
   440  				t.Fatalf("cookie %v verification returned username %s, wanted %s", tc.cookie, a, e)
   441  			}
   442  		})
   443  	}
   444  }
   445  
   446  func TestAuthenticationAPIUserLogin(t *testing.T) {
   447  	defer leaktest.AfterTest(t)()
   448  	s, db, _ := serverutils.StartServer(t, base.TestServerArgs{})
   449  	defer s.Stopper().Stop(context.Background())
   450  	ts := s.(*TestServer)
   451  
   452  	const (
   453  		validUsername = "testuser"
   454  		validPassword = "password"
   455  	)
   456  
   457  	cmd := fmt.Sprintf("CREATE USER %s WITH PASSWORD '%s'", validUsername, validPassword)
   458  	if _, err := db.Exec(cmd); err != nil {
   459  		t.Fatalf("failed to create user: %s", err)
   460  	}
   461  
   462  	tryLogin := func(username, password string) (*http.Response, error) {
   463  		// We need to instantiate our own HTTP Request, because we must inspect
   464  		// the returned headers.
   465  		httpClient, err := ts.GetHTTPClient()
   466  		if util.RaceEnabled {
   467  			httpClient.Timeout += 30 * time.Second
   468  		}
   469  		if err != nil {
   470  			t.Fatalf("could not get HTTP client: %s", err)
   471  		}
   472  		req := serverpb.UserLoginRequest{
   473  			Username: username,
   474  			Password: password,
   475  		}
   476  		var resp serverpb.UserLoginResponse
   477  		return httputil.PostJSONWithRequest(
   478  			httpClient, ts.AdminURL()+loginPath, &req, &resp,
   479  		)
   480  	}
   481  
   482  	// Unsuccessful attempt. Should come back with a 401 and no "Set-Cookie"
   483  	{
   484  		response, err := tryLogin(validUsername, "wrongpassword")
   485  		if !testutils.IsError(err, "status: 401") {
   486  			t.Fatalf("login got error %s, wanted error with 401 status", err)
   487  		}
   488  		if cookies := response.Cookies(); len(cookies) > 0 {
   489  			t.Fatalf("bad login got cookies %v, wanted empty", cookies)
   490  		}
   491  	}
   492  
   493  	// Successful attempt. Should succeed and return a Set-Cookie header.
   494  	response, err := tryLogin(validUsername, validPassword)
   495  	if err != nil {
   496  		t.Fatalf("good login got error %s, wanted no error", err)
   497  	}
   498  	cookies := response.Cookies()
   499  	if len(cookies) == 0 {
   500  		t.Fatalf("good login got no cookies: %v", response)
   501  	}
   502  
   503  	sessionCookie, err := decodeSessionCookie(cookies[0])
   504  	if err != nil {
   505  		t.Fatalf("failed to decode session cookie: %s", err)
   506  	}
   507  
   508  	// Look up session in database and verify hashed secret value and username.
   509  	query := `SELECT "hashedSecret", "username" FROM system.web_sessions WHERE id = $1`
   510  	result := db.QueryRow(query, sessionCookie.ID)
   511  	var (
   512  		sessHashedSecret []byte
   513  		sessUsername     string
   514  	)
   515  	if err := result.Scan(&sessHashedSecret, &sessUsername); err != nil {
   516  		t.Fatalf("error querying auth session: %s", err)
   517  	}
   518  
   519  	if a, e := sessUsername, validUsername; a != e {
   520  		t.Fatalf("created auth session had username %s, wanted %s", a, e)
   521  	}
   522  
   523  	hasher := sha256.New()
   524  	_, _ = hasher.Write(sessionCookie.Secret)
   525  	hashedSecret := hasher.Sum(nil)
   526  	if a, e := sessHashedSecret, hashedSecret; !bytes.Equal(a, e) {
   527  		t.Fatalf(
   528  			"session secret hash was %v, wanted %v (derived from original secret %v)",
   529  			a,
   530  			e,
   531  			sessionCookie.Secret,
   532  		)
   533  	}
   534  }
   535  
   536  func TestLogout(t *testing.T) {
   537  	defer leaktest.AfterTest(t)()
   538  	s, db, _ := serverutils.StartServer(t, base.TestServerArgs{})
   539  	defer s.Stopper().Stop(context.Background())
   540  	ts := s.(*TestServer)
   541  
   542  	// Log in.
   543  	authHTTPClient, cookie, err := ts.getAuthenticatedHTTPClientAndCookie(authenticatedUserName, true)
   544  	if err != nil {
   545  		t.Fatal("error opening HTTP client", err)
   546  	}
   547  
   548  	// Log out.
   549  	var resp serverpb.UserLogoutResponse
   550  	if err := httputil.GetJSON(authHTTPClient, ts.AdminURL()+logoutPath, &resp); err != nil {
   551  		t.Fatal("logout request failed:", err)
   552  	}
   553  
   554  	// Verify that revokedAt has been set in the DB.
   555  	query := `SELECT "revokedAt" FROM system.web_sessions WHERE id = $1`
   556  	result := db.QueryRow(query, cookie.ID)
   557  	var revokedAt string
   558  	if err := result.Scan(&revokedAt); err != nil {
   559  		t.Fatalf("error querying auth session: %s", err)
   560  	}
   561  
   562  	if revokedAt == "" {
   563  		t.Fatal("expected revoked at to not be empty; was empty")
   564  	}
   565  
   566  	databasesURL := ts.AdminURL() + "/_admin/v1/databases"
   567  
   568  	// Verify that we're unauthorized after logout.
   569  	response, err := authHTTPClient.Get(databasesURL)
   570  	if err != nil {
   571  		t.Fatal(err)
   572  	}
   573  	defer response.Body.Close()
   574  
   575  	if response.StatusCode != http.StatusUnauthorized {
   576  		t.Fatal("expected unauthorized response after logout; got", response.StatusCode)
   577  	}
   578  
   579  	// Try to use the revoked cookie; verify that it doesn't work.
   580  	parsedURL, err := url.Parse(s.AdminURL())
   581  	if err != nil {
   582  		t.Fatal(err)
   583  	}
   584  	encodedCookie, err := EncodeSessionCookie(cookie, false /* forHTTPSOnly */)
   585  	if err != nil {
   586  		t.Fatal(err)
   587  	}
   588  
   589  	invalidAuthClient, err := s.GetHTTPClient()
   590  	if err != nil {
   591  		t.Fatal(err)
   592  	}
   593  	jar, err := cookiejar.New(nil)
   594  	if err != nil {
   595  		t.Fatal(err)
   596  	}
   597  	invalidAuthClient.Jar = jar
   598  	invalidAuthClient.Jar.SetCookies(parsedURL, []*http.Cookie{encodedCookie})
   599  
   600  	invalidAuthResp, err := invalidAuthClient.Get(databasesURL)
   601  	if err != nil {
   602  		t.Fatal(err)
   603  	}
   604  	defer invalidAuthResp.Body.Close()
   605  
   606  	if invalidAuthResp.StatusCode != 401 {
   607  		t.Fatal("expected unauthorized error; got", invalidAuthResp.StatusCode)
   608  	}
   609  }
   610  
   611  // TestAuthenticationMux verifies that the authentication handler is used by all
   612  // of the APIs it should be protecting. Authentication is enabled by default for
   613  // the test server, and every test which accesses APIs uses an authenticated
   614  // client (except for a few that specifically override it).  Therefore, this
   615  // test verifies that authentication mux is attached to services at all by
   616  // testing an endpoint of each with a verified and unverified client.
   617  func TestAuthenticationMux(t *testing.T) {
   618  	defer leaktest.AfterTest(t)()
   619  	s, _, _ := serverutils.StartServer(t, base.TestServerArgs{})
   620  	defer s.Stopper().Stop(context.Background())
   621  	tsrv := s.(*TestServer)
   622  
   623  	// Both the normal and authenticated client will be used for each test.
   624  	normalClient, err := tsrv.GetHTTPClient()
   625  	if err != nil {
   626  		t.Fatal(err)
   627  	}
   628  	authClient, err := tsrv.GetAdminAuthenticatedHTTPClient()
   629  	if err != nil {
   630  		t.Fatal(err)
   631  	}
   632  
   633  	runRequest := func(
   634  		client http.Client, method string, path string, body []byte, expected int,
   635  	) error {
   636  		req, err := http.NewRequest(method, tsrv.AdminURL()+path, bytes.NewBuffer(body))
   637  		if err != nil {
   638  			return err
   639  		}
   640  		resp, err := client.Do(req)
   641  		if err != nil {
   642  			return err
   643  		}
   644  		defer resp.Body.Close()
   645  		if a, e := resp.StatusCode, expected; a != e {
   646  			message, err := ioutil.ReadAll(resp.Body)
   647  			if err != nil {
   648  				message = []byte(err.Error())
   649  			}
   650  			return errors.Errorf("got status code %d (msg %s), wanted %d", a, string(message), e)
   651  		}
   652  		return nil
   653  	}
   654  
   655  	// Generate request for time series API.
   656  	tsReq := tspb.TimeSeriesQueryRequest{
   657  		StartNanos: 0,
   658  		EndNanos:   100 * 1e9,
   659  		Queries:    []tspb.Query{{Name: "test.metric"}},
   660  	}
   661  	var tsReqBuffer bytes.Buffer
   662  	marshalFn := (&jsonpb.Marshaler{}).Marshal
   663  	if err := marshalFn(&tsReqBuffer, &tsReq); err != nil {
   664  		t.Fatal(err)
   665  	}
   666  
   667  	for _, tc := range []struct {
   668  		method string
   669  		path   string
   670  		body   []byte
   671  	}{
   672  		{"GET", adminPrefix + "users", nil},
   673  		{"GET", statusPrefix + "sessions", nil},
   674  		{"POST", ts.URLPrefix + "query", tsReqBuffer.Bytes()},
   675  	} {
   676  		t.Run("path="+tc.path, func(t *testing.T) {
   677  			// Verify normal client returns 401 Unauthorized.
   678  			if err := runRequest(normalClient, tc.method, tc.path, tc.body, http.StatusUnauthorized); err != nil {
   679  				t.Fatalf("request %s failed when not authorized: %s", tc.path, err)
   680  			}
   681  
   682  			// Verify authenticated client returns 200 OK.
   683  			if err := runRequest(authClient, tc.method, tc.path, tc.body, http.StatusOK); err != nil {
   684  				t.Fatalf("request %s failed when authorized: %s", tc.path, err)
   685  			}
   686  		})
   687  	}
   688  }
   689  
   690  func TestGRPCAuthentication(t *testing.T) {
   691  	defer leaktest.AfterTest(t)()
   692  
   693  	ctx := context.Background()
   694  	s, _, _ := serverutils.StartServer(t, base.TestServerArgs{})
   695  	defer s.Stopper().Stop(ctx)
   696  
   697  	// For each subsystem we pick a representative RPC. The idea is not to
   698  	// exhaustively test each RPC but to prevent server startup from being
   699  	// refactored in such a way that an entire subsystem becomes inadvertently
   700  	// exempt from authentication checks.
   701  	subsystems := []struct {
   702  		name    string
   703  		sendRPC func(context.Context, *grpc.ClientConn) error
   704  	}{
   705  		{"gossip", func(ctx context.Context, conn *grpc.ClientConn) error {
   706  			stream, err := gossip.NewGossipClient(conn).Gossip(ctx)
   707  			if err != nil {
   708  				return err
   709  			}
   710  			_ = stream.Send(&gossip.Request{})
   711  			_, err = stream.Recv()
   712  			return err
   713  		}},
   714  		{"internal", func(ctx context.Context, conn *grpc.ClientConn) error {
   715  			_, err := roachpb.NewInternalClient(conn).Batch(ctx, &roachpb.BatchRequest{})
   716  			return err
   717  		}},
   718  		{"perReplica", func(ctx context.Context, conn *grpc.ClientConn) error {
   719  			_, err := kvserver.NewPerReplicaClient(conn).CollectChecksum(ctx, &kvserver.CollectChecksumRequest{})
   720  			return err
   721  		}},
   722  		{"raft", func(ctx context.Context, conn *grpc.ClientConn) error {
   723  			stream, err := kvserver.NewMultiRaftClient(conn).RaftMessageBatch(ctx)
   724  			if err != nil {
   725  				return err
   726  			}
   727  			_ = stream.Send(&kvserver.RaftMessageRequestBatch{})
   728  			_, err = stream.Recv()
   729  			return err
   730  		}},
   731  		{"closedTimestamp", func(ctx context.Context, conn *grpc.ClientConn) error {
   732  			stream, err := ctpb.NewClosedTimestampClient(conn).Get(ctx)
   733  			if err != nil {
   734  				return err
   735  			}
   736  			_ = stream.Send(&ctpb.Reaction{})
   737  			_, err = stream.Recv()
   738  			return err
   739  		}},
   740  		{"distSQL", func(ctx context.Context, conn *grpc.ClientConn) error {
   741  			stream, err := execinfrapb.NewDistSQLClient(conn).RunSyncFlow(ctx)
   742  			if err != nil {
   743  				return err
   744  			}
   745  			_ = stream.Send(&execinfrapb.ConsumerSignal{})
   746  			_, err = stream.Recv()
   747  			return err
   748  		}},
   749  		{"init", func(ctx context.Context, conn *grpc.ClientConn) error {
   750  			_, err := serverpb.NewInitClient(conn).Bootstrap(ctx, &serverpb.BootstrapRequest{})
   751  			return err
   752  		}},
   753  		{"admin", func(ctx context.Context, conn *grpc.ClientConn) error {
   754  			_, err := serverpb.NewAdminClient(conn).Databases(ctx, &serverpb.DatabasesRequest{})
   755  			return err
   756  		}},
   757  		{"status", func(ctx context.Context, conn *grpc.ClientConn) error {
   758  			_, err := serverpb.NewStatusClient(conn).ListSessions(ctx, &serverpb.ListSessionsRequest{})
   759  			return err
   760  		}},
   761  	}
   762  
   763  	conn, err := grpc.DialContext(ctx, s.ServingRPCAddr(),
   764  		grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
   765  			InsecureSkipVerify: true,
   766  		})))
   767  	if err != nil {
   768  		t.Fatal(err)
   769  	}
   770  	defer func(conn *grpc.ClientConn) { _ = conn.Close() }(conn)
   771  	for _, subsystem := range subsystems {
   772  		t.Run(fmt.Sprintf("no-cert/%s", subsystem.name), func(t *testing.T) {
   773  			err := subsystem.sendRPC(ctx, conn)
   774  			if exp := "no client certificates in request"; !testutils.IsError(err, exp) {
   775  				t.Errorf("expected %q error, but got %v", exp, err)
   776  			}
   777  		})
   778  	}
   779  
   780  	certManager, err := s.RPCContext().GetCertificateManager()
   781  	if err != nil {
   782  		t.Fatal(err)
   783  	}
   784  	tlsConfig, err := certManager.GetClientTLSConfig("testuser")
   785  	if err != nil {
   786  		t.Fatal(err)
   787  	}
   788  	conn, err = grpc.DialContext(ctx, s.ServingRPCAddr(),
   789  		grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)))
   790  	if err != nil {
   791  		t.Fatal(err)
   792  	}
   793  	defer func(conn *grpc.ClientConn) { _ = conn.Close() }(conn)
   794  	for _, subsystem := range subsystems {
   795  		t.Run(fmt.Sprintf("bad-user/%s", subsystem.name), func(t *testing.T) {
   796  			err := subsystem.sendRPC(ctx, conn)
   797  			if exp := `user \[testuser\] is not allowed to perform this RPC`; !testutils.IsError(err, exp) {
   798  				t.Errorf("expected %q error, but got %v", exp, err)
   799  			}
   800  		})
   801  	}
   802  }