github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/pgwire/auth_test.go (about) 1 // Copyright 2019 The Cockroach Authors. 2 // 3 // Use of this software is governed by the Business Source License 4 // included in the file licenses/BSL.txt. 5 // 6 // As of the Change Date specified in that file, in accordance with 7 // the Business Source License, use of this software will be governed 8 // by the Apache License, Version 2.0, included in the file 9 // licenses/APL.txt. 10 11 package pgwire_test 12 13 import ( 14 "context" 15 gosql "database/sql" 16 "fmt" 17 "io/ioutil" 18 "math" 19 "net" 20 "net/url" 21 "os" 22 "path/filepath" 23 "regexp" 24 "runtime" 25 "strconv" 26 "strings" 27 "testing" 28 29 "github.com/cockroachdb/cockroach/pkg/base" 30 "github.com/cockroachdb/cockroach/pkg/security" 31 "github.com/cockroachdb/cockroach/pkg/server" 32 "github.com/cockroachdb/cockroach/pkg/sql/pgwire" 33 "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" 34 "github.com/cockroachdb/cockroach/pkg/testutils" 35 "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" 36 "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" 37 "github.com/cockroachdb/cockroach/pkg/util/leaktest" 38 "github.com/cockroachdb/cockroach/pkg/util/log" 39 "github.com/cockroachdb/datadriven" 40 "github.com/cockroachdb/errors" 41 "github.com/cockroachdb/errors/stdstrings" 42 "github.com/lib/pq" 43 ) 44 45 // TestAuthenticationAndHBARules exercises the authentication code 46 // using datadriven testing. 47 // 48 // It supports the following DSL: 49 // 50 // config [secure] [insecure] 51 // Only run the test file if the server is in the specified 52 // security mode. (The default is `config secure insecure` i.e. 53 // the test file is applicable to both.) 54 // 55 // set_hba 56 // <hba config> 57 // Load the provided HBA configuration via the cluster setting 58 // server.host_based_authentication.configuration. 
//       The expected output is the configuration after parsing
//       and reloading in the server.
//
// sql
// <sql input>
//       Execute the specified SQL statement using the default root
//       connection provided by StartServer().
//
// authlog N
// <regexp>
//       Wait (retrying for a while) until the most recent auth/conn
//       log entry matches <regexp>, then report the N most recent
//       entries, including the matching one, oldest first. If N is
//       out of range, all fetched entries are reported. If no match
//       is found in time, an "ERROR: ..." line is appended.
//
// connect [key=value ...]
//       Attempt a SQL connection using the provided connection
//       parameters using the pg "DSN notation": k/v pairs separated
//       by spaces.
//       The following standard pg keys are recognized:
//         user - the username
//         password - the password
//         host - the server name/address
//         port - the server port
//         sslmode, sslrootcert, sslcert, sslkey - SSL parameters.
//
//       The order of k/v pairs matters: if the same key is specified
//       multiple times, the first occurrence takes priority.
//
//       Additionally, the test runner will always _append_ a default
//       value for user (root), host/port/sslrootcert from the
//       initialized test server. This default configuration is placed
//       at the end so that each test can override the values.
//
//       The test runner also adds a default value for sslcert and
//       sslkey based on the value of "user" — either when provided by
//       the test, or root by default.
//
//       When the user is either "root" or "testuser" (those are the
//       users for which the test server generates certificates),
//       sslmode also gets a default of "verify-full". For other
//       users, sslmode is initialized by default to "verify-ca".
//
// connect_unix [key=value ...]
//       Same as connect, but connect through the server's unix socket
//       instead of TCP. On Windows, where unix sockets are not
//       supported, the directive is skipped and its expected output
//       is accepted as-is.
//
// For the directives "sql", "connect" and "connect_unix", the expected
// output is either "ok" (no error) or "ERROR:" followed by the expected
// error string. On success, "connect" and "connect_unix" report "ok"
// followed by the name of the current database. Auth and connection
// log entries are not part of this output; use the "authlog" directive
// to check them.
105 // 106 func TestAuthenticationAndHBARules(t *testing.T) { 107 defer leaktest.AfterTest(t)() 108 109 testutils.RunTrueAndFalse(t, "insecure", func(t *testing.T, insecure bool) { 110 hbaRunTest(t, insecure) 111 }) 112 } 113 114 const socketConnVirtualPort = "6" 115 116 func makeSocketFile(t *testing.T) (socketDir, socketFile string, cleanupFn func()) { 117 if runtime.GOOS == "windows" { 118 // Unix sockets not supported on windows. 119 return "", "", func() {} 120 } 121 // We need a temp directory in which we'll create the unix socket. 122 // 123 // On BSD, binding to a socket is limited to a path length of 104 characters 124 // (including the NUL terminator). In glibc, this limit is 108 characters. 125 // 126 // macOS has a tendency to produce very long temporary directory names, so 127 // we are careful to keep all the constants involved short. 128 tempDir, err := ioutil.TempDir("", "TestAuth") 129 if err != nil { 130 t.Fatal(err) 131 } 132 // ".s.PGSQL.NNNN" is the standard unix socket name supported by pg clients. 133 return tempDir, 134 filepath.Join(tempDir, ".s.PGSQL."+socketConnVirtualPort), 135 func() { _ = os.RemoveAll(tempDir) } 136 } 137 138 func hbaRunTest(t *testing.T, insecure bool) { 139 httpScheme := "http://" 140 if !insecure { 141 httpScheme = "https://" 142 } 143 datadriven.Walk(t, "testdata/auth", func(t *testing.T, path string) { 144 maybeSocketDir, maybeSocketFile, cleanup := makeSocketFile(t) 145 defer cleanup() 146 147 s, conn, _ := serverutils.StartServer(t, 148 base.TestServerArgs{Insecure: insecure, SocketFile: maybeSocketFile}) 149 defer s.Stopper().Stop(context.Background()) 150 151 // Enable conn/auth logging. 152 // We can't use the cluster settings to do this, because 153 // cluster settings propagate asynchronously. 154 s.(*server.TestServer).PGServer().TestingEnableConnAuthLogging() 155 156 // We really need to have the logs go to files, so that -show-logs 157 // does not break the "authlog" directives. 
We also must call 158 // this here and not earlier, because it needs to enforce the 159 // redirect on the secondary loggers created by StartServer(). 160 defer log.ScopeWithoutShowLogs(t).Close(t) 161 162 pgServer := s.(*server.TestServer).PGServer() 163 164 httpClient, err := s.GetHTTPClient() 165 if err != nil { 166 t.Fatal(err) 167 } 168 httpHBAUrl := httpScheme + s.HTTPAddr() + "/debug/hba_conf" 169 170 if _, err := conn.ExecContext(context.Background(), `CREATE USER $1`, server.TestUser); err != nil { 171 t.Fatal(err) 172 } 173 174 datadriven.RunTest(t, path, func(t *testing.T, td *datadriven.TestData) string { 175 resultString, err := func() (string, error) { 176 switch td.Cmd { 177 case "config": 178 allowed := false 179 for _, a := range td.CmdArgs { 180 switch a.Key { 181 case "secure": 182 allowed = allowed || !insecure 183 case "insecure": 184 allowed = allowed || insecure 185 default: 186 t.Fatalf("unknown configuration: %s", a.Key) 187 } 188 } 189 if !allowed { 190 t.Skip("Test file not applicable at this security level.") 191 } 192 193 case "set_hba": 194 _, err := conn.ExecContext(context.Background(), 195 `SET CLUSTER SETTING server.host_based_authentication.configuration = $1`, td.Input) 196 if err != nil { 197 return "", err 198 } 199 200 // Wait until the configuration has propagated back to the 201 // test client. We need to wait because the cluster setting 202 // change propagates asynchronously. 203 expConf := pgwire.DefaultHBAConfig 204 if td.Input != "" { 205 expConf, err = pgwire.ParseAndNormalize(td.Input) 206 if err != nil { 207 // The SET above succeeded so we don't expect a problem here. 
208 t.Fatal(err) 209 } 210 } 211 testutils.SucceedsSoon(t, func() error { 212 curConf := pgServer.GetAuthenticationConfiguration() 213 if expConf.String() != curConf.String() { 214 return errors.Newf( 215 "HBA config not yet loaded\ngot:\n%s\nexpected:\n%s", 216 curConf, expConf) 217 } 218 return nil 219 }) 220 221 // Verify the HBA configuration was processed properly by 222 // reporting the resulting cached configuration. 223 resp, err := httpClient.Get(httpHBAUrl) 224 if err != nil { 225 return "", err 226 } 227 defer resp.Body.Close() 228 body, err := ioutil.ReadAll(resp.Body) 229 if err != nil { 230 return "", err 231 } 232 return string(body), nil 233 234 case "sql": 235 _, err := conn.ExecContext(context.Background(), td.Input) 236 return "ok", err 237 238 case "authlog": 239 if len(td.CmdArgs) < 0 { 240 t.Fatal("not enough arguments") 241 } 242 numEntries, err := strconv.Atoi(td.CmdArgs[0].Key) 243 if err != nil { 244 t.Fatal(err) 245 } 246 re, err := regexp.Compile(td.Input) 247 if err != nil { 248 t.Fatal(err) 249 } 250 251 var buf strings.Builder 252 if err := testutils.SucceedsSoonError(func() error { 253 buf.Reset() 254 t.Logf("attempting to scan logs...") 255 256 // Note: even though FetchEntriesFromFiles advertises a mechanism 257 // to filter entries by timestamp or just retrieve the last N entries, 258 // this is currently broken for secondary loggers. 259 // See: https://github.com/cockroachdb/cockroach/issues/45745 260 // So instead we need to do the filtering ourselves. 261 entries, err := log.FetchEntriesFromFiles(0, math.MaxInt64, 10000, authLogFileRe) 262 if err != nil { 263 t.Fatal(err) 264 } 265 if len(entries) == 0 { 266 return errors.New("no log entries") 267 } else { 268 // Note: entries are delivered by Fetch in reverse order. 
269 i := numEntries - 1 270 if i < 0 || i >= len(entries) { 271 i = len(entries) - 1 272 } 273 for ; i >= 0; i-- { 274 entry := &entries[i] 275 t.Logf("found log entry: %+v", *entry) 276 277 // The message is going to contain a client address, with a random port number. 278 // To make the test deterministic, erase the random part. 279 msg := addrRe.ReplaceAllString(entry.Message, ",client=XXX") 280 // Ditto with the duration. 281 msg = durationRe.ReplaceAllString(msg, "duration: XXX") 282 283 fmt.Fprintf(&buf, "%c: %s\n", entry.Severity.String()[0], msg) 284 } 285 lastLogMsg := entries[0].Message 286 if !re.MatchString(lastLogMsg) { 287 return errors.Newf("last entry does not match: %q", lastLogMsg) 288 } 289 } 290 return nil 291 }); err != nil { 292 buf.WriteString("ERROR: unable to find log line matching regexp") 293 } 294 return buf.String(), nil 295 296 case "connect", "connect_unix": 297 if td.Cmd == "connect_unix" && runtime.GOOS == "windows" { 298 // Unix sockets not supported; assume the test succeeded. 299 return td.Expected, nil 300 } 301 302 // Prepare a connection string using the server's default. 303 // What is the user requested by the test? 304 user := security.RootUser 305 if td.HasArg("user") { 306 td.ScanArgs(t, "user", &user) 307 } 308 309 // We want the certs to be present in the filesystem for this test. 310 // However, certs are only generated for users "root" and "testuser" specifically. 
311 sqlURL, cleanupFn := sqlutils.PGUrlWithOptionalClientCerts( 312 t, s.ServingSQLAddr(), t.Name(), url.User(user), 313 user == security.RootUser || user == server.TestUser /* withClientCerts */) 314 defer cleanupFn() 315 316 var host, port string 317 if td.Cmd == "connect" { 318 host, port, err = net.SplitHostPort(s.ServingSQLAddr()) 319 if err != nil { 320 t.Fatal(err) 321 } 322 } else /* unix */ { 323 host = maybeSocketDir 324 port = socketConnVirtualPort 325 } 326 options, err := url.ParseQuery(sqlURL.RawQuery) 327 if err != nil { 328 t.Fatal(err) 329 } 330 331 // Here we make use of the fact that pq accepts connection 332 // strings using the alternate postgres configuration format, 333 // consisting of k=v pairs separated by spaces. 334 // For example this is a valid connection string: 335 // "host=localhost port=5432 user=root" 336 // We also make use of the datadriven K/V parsing facility, 337 // which always prioritizes the first K instance in the test's 338 // argument list. We append the server's config parameters 339 // at the end, letting the test override by introducing its 340 // own values at the beginning. 341 args := append(td.CmdArgs, 342 datadriven.CmdArg{Key: "user", Vals: []string{user}}, 343 datadriven.CmdArg{Key: "host", Vals: []string{host}}, 344 datadriven.CmdArg{Key: "port", Vals: []string{port}}, 345 ) 346 for key := range options { 347 args = append(args, 348 datadriven.CmdArg{Key: key, Vals: []string{options.Get(key)}}) 349 } 350 // Now turn the cmdargs into a dsn. 351 var dsnBuf strings.Builder 352 sp := "" 353 seenKeys := map[string]struct{}{} 354 for _, a := range args { 355 if _, ok := seenKeys[a.Key]; ok { 356 continue 357 } 358 seenKeys[a.Key] = struct{}{} 359 val := "" 360 if len(a.Vals) > 0 { 361 val = a.Vals[0] 362 } 363 fmt.Fprintf(&dsnBuf, "%s%s=%s", sp, a.Key, val) 364 sp = " " 365 } 366 dsn := dsnBuf.String() 367 368 // Finally, connect and test the connection. 
369 dbSQL, err := gosql.Open("postgres", dsn) 370 if dbSQL != nil { 371 // Note: gosql.Open may return a valid db (with an open 372 // TCP connection) even if there is an error. We want to 373 // ensure this gets closed so that we catch the conn close 374 // message in the log. 375 defer dbSQL.Close() 376 } 377 if err != nil { 378 return "", err 379 } 380 row := dbSQL.QueryRow("SELECT current_catalog") 381 var dbName string 382 if err := row.Scan(&dbName); err != nil { 383 return "", err 384 } 385 return "ok " + dbName, nil 386 387 default: 388 td.Fatalf(t, "unknown command: %s", td.Cmd) 389 } 390 return "", nil 391 }() 392 if err != nil { 393 return fmtErr(err) 394 } 395 return resultString 396 }) 397 }) 398 } 399 400 var authLogFileRe = regexp.MustCompile(`pgwire/(auth|conn|server)\.go`) 401 var addrRe = regexp.MustCompile(`,client(=[^\],]*)?`) 402 var durationRe = regexp.MustCompile(`duration: \d.*s`) 403 404 // fmtErr formats an error into an expected output. 405 func fmtErr(err error) string { 406 if err != nil { 407 errStr := "" 408 if pqErr := (*pq.Error)(nil); errors.As(err, &pqErr) { 409 errStr = pqErr.Message 410 if pqErr.Code != pgcode.Uncategorized { 411 errStr += fmt.Sprintf(" (SQLSTATE %s)", pqErr.Code) 412 } 413 if pqErr.Hint != "" { 414 hint := strings.Replace(pqErr.Hint, stdstrings.IssueReferral, "<STANDARD REFERRAL>", 1) 415 errStr += "\nHINT: " + hint 416 } 417 if pqErr.Detail != "" { 418 errStr += "\nDETAIL: " + pqErr.Detail 419 } 420 } else { 421 errStr = err.Error() 422 } 423 return "ERROR: " + errStr 424 } 425 return "ok" 426 }