github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/ccl/serverccl/role_authentication_test.go (about)

     1  // Copyright 2020 The Cockroach Authors.
     2  //
     3  // Licensed as a CockroachDB Enterprise file under the Cockroach Community
     4  // License (the "License"); you may not use this file except in compliance with
     5  // the License. You may obtain a copy of the License at
     6  //
     7  //     https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt
     8  
     9  package serverccl
    10  
    11  import (
    12  	"context"
    13  	"fmt"
    14  	"testing"
    15  	"time"
    16  
    17  	"github.com/cockroachdb/cockroach/pkg/base"
    18  	"github.com/cockroachdb/cockroach/pkg/security"
    19  	"github.com/cockroachdb/cockroach/pkg/server"
    20  	"github.com/cockroachdb/cockroach/pkg/sql"
    21  	"github.com/cockroachdb/cockroach/pkg/testutils/serverutils"
    22  	"github.com/cockroachdb/cockroach/pkg/util"
    23  	"github.com/cockroachdb/cockroach/pkg/util/leaktest"
    24  	"github.com/cockroachdb/cockroach/pkg/util/timeutil"
    25  	"golang.org/x/crypto/bcrypt"
    26  )
    27  
    28  func TestVerifyPassword(t *testing.T) {
    29  	defer leaktest.AfterTest(t)()
    30  
    31  	ctx := context.Background()
    32  	s, db, _ := serverutils.StartServer(t, base.TestServerArgs{})
    33  	defer s.Stopper().Stop(ctx)
    34  
    35  	ie := sql.MakeInternalExecutor(
    36  		context.Background(),
    37  		s.(*server.TestServer).Server.PGServer().SQLServer,
    38  		sql.MemoryMetrics{},
    39  		s.ExecutorConfig().(sql.ExecutorConfig).Settings,
    40  	)
    41  
    42  	if util.RaceEnabled {
    43  		// The default bcrypt cost makes this test approximately 30s slower when the
    44  		// race detector is on.
    45  		defer func(prev int) { security.BcryptCost = prev }(security.BcryptCost)
    46  		security.BcryptCost = bcrypt.MinCost
    47  	}
    48  
    49  	//location is used for timezone testing.
    50  	shanghaiLoc, err := time.LoadLocation("Asia/Shanghai")
    51  	if err != nil {
    52  		t.Fatal(err)
    53  	}
    54  
    55  	for _, user := range []struct {
    56  		username         string
    57  		password         string
    58  		loginFlag        string
    59  		validUntilClause string
    60  		qargs            []interface{}
    61  	}{
    62  		{"azure_diamond", "hunter2", "LOGIN", "", nil},
    63  		{"druidia", "12345", "LOGIN", "", nil},
    64  
    65  		{"richardc", "12345", "NOLOGIN", "", nil},
    66  		{"before_epoch", "12345", "", "VALID UNTIL '1969-01-01'", nil},
    67  		{"epoch", "12345", "", "VALID UNTIL '1970-01-01'", nil},
    68  		{"cockroach", "12345", "", "VALID UNTIL '2100-01-01'", nil},
    69  		{"cthon98", "12345", "", "VALID UNTIL NULL", nil},
    70  
    71  		{"toolate", "12345", "", "VALID UNTIL $1",
    72  			[]interface{}{timeutil.Now().Add(-10 * time.Minute)}},
    73  		{"timelord", "12345", "", "VALID UNTIL $1",
    74  			[]interface{}{timeutil.Now().Add(59 * time.Minute).In(shanghaiLoc)}},
    75  	} {
    76  		cmd := fmt.Sprintf(
    77  			"CREATE ROLE %s WITH PASSWORD '%s' %s %s",
    78  			user.username, user.password, user.loginFlag, user.validUntilClause)
    79  
    80  		if _, err := db.Exec(cmd, user.qargs...); err != nil {
    81  			t.Fatalf("failed to create user: %s", err)
    82  		}
    83  	}
    84  
    85  	for _, tc := range []struct {
    86  		username           string
    87  		password           string
    88  		shouldAuthenticate bool
    89  		expectedErrString  string
    90  	}{
    91  		{"azure_diamond", "hunter2", true, ""},
    92  		{"azure_diamond", "hunter", false, "crypto/bcrypt"},
    93  		{"azure_diamond", "", false, "crypto/bcrypt"},
    94  		{"azure_diamond", "🍦", false, "crypto/bcrypt"},
    95  		{"azure_diamond", "hunter2345", false, "crypto/bcrypt"},
    96  		{"azure_diamond", "shunter2", false, "crypto/bcrypt"},
    97  		{"azure_diamond", "12345", false, "crypto/bcrypt"},
    98  		{"azure_diamond", "*******", false, "crypto/bcrypt"},
    99  		{"druidia", "12345", true, ""},
   100  		{"druidia", "hunter2", false, "crypto/bcrypt"},
   101  		{"root", "", false, "crypto/bcrypt"},
   102  		{"", "", false, "does not exist"},
   103  		{"doesntexist", "zxcvbn", false, "does not exist"},
   104  
   105  		{"richardc", "12345", false,
   106  			"richardc does not have login privilege"},
   107  		{"before_epoch", "12345", false, ""},
   108  		{"epoch", "12345", false, ""},
   109  		{"cockroach", "12345", true, ""},
   110  		{"toolate", "12345", false, ""},
   111  		{"timelord", "12345", true, ""},
   112  		{"cthon98", "12345", true, ""},
   113  	} {
   114  		t.Run("", func(t *testing.T) {
   115  			exists, canLogin, pwRetrieveFn, validUntilFn, err := sql.GetUserHashedPassword(context.Background(), &ie, tc.username)
   116  
   117  			if err != nil {
   118  				t.Errorf(
   119  					"credentials %s/%s failed with error %s, wanted no error",
   120  					tc.username,
   121  					tc.password,
   122  					err,
   123  				)
   124  			}
   125  
   126  			valid := true
   127  			expired := false
   128  
   129  			if !exists || !canLogin {
   130  				valid = false
   131  			}
   132  
   133  			hashedPassword, err := pwRetrieveFn(ctx)
   134  			if err != nil {
   135  				t.Errorf(
   136  					"credentials %s/%s failed with error %s, wanted no error",
   137  					tc.username,
   138  					tc.password,
   139  					err,
   140  				)
   141  			}
   142  
   143  			err = security.CompareHashAndPassword(hashedPassword, tc.password)
   144  			if err != nil {
   145  				valid = false
   146  			}
   147  
   148  			validUntil, err := validUntilFn(ctx)
   149  			if err != nil {
   150  				t.Errorf(
   151  					"credentials %s/%s failed with error %s, wanted no error",
   152  					tc.username,
   153  					tc.password,
   154  					err,
   155  				)
   156  			}
   157  
   158  			if validUntil != nil {
   159  				if validUntil.Time.Sub(timeutil.Now()) < 0 {
   160  					expired = true
   161  				}
   162  			}
   163  
   164  			if valid && !expired != tc.shouldAuthenticate {
   165  				t.Errorf(
   166  					"credentials %s/%s valid = %t, wanted %t",
   167  					tc.username,
   168  					tc.password,
   169  					valid,
   170  					tc.shouldAuthenticate,
   171  				)
   172  			}
   173  		})
   174  	}
   175  }