github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/pgwire/auth_test.go (about)

     1  // Copyright 2019 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package pgwire_test
    12  
    13  import (
    14  	"context"
    15  	gosql "database/sql"
    16  	"fmt"
    17  	"io/ioutil"
    18  	"math"
    19  	"net"
    20  	"net/url"
    21  	"os"
    22  	"path/filepath"
    23  	"regexp"
    24  	"runtime"
    25  	"strconv"
    26  	"strings"
    27  	"testing"
    28  
    29  	"github.com/cockroachdb/cockroach/pkg/base"
    30  	"github.com/cockroachdb/cockroach/pkg/security"
    31  	"github.com/cockroachdb/cockroach/pkg/server"
    32  	"github.com/cockroachdb/cockroach/pkg/sql/pgwire"
    33  	"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode"
    34  	"github.com/cockroachdb/cockroach/pkg/testutils"
    35  	"github.com/cockroachdb/cockroach/pkg/testutils/serverutils"
    36  	"github.com/cockroachdb/cockroach/pkg/testutils/sqlutils"
    37  	"github.com/cockroachdb/cockroach/pkg/util/leaktest"
    38  	"github.com/cockroachdb/cockroach/pkg/util/log"
    39  	"github.com/cockroachdb/datadriven"
    40  	"github.com/cockroachdb/errors"
    41  	"github.com/cockroachdb/errors/stdstrings"
    42  	"github.com/lib/pq"
    43  )
    44  
    45  // TestAuthenticationAndHBARules exercises the authentication code
    46  // using datadriven testing.
    47  //
    48  // It supports the following DSL:
    49  //
    50  // config [secure] [insecure]
    51  //       Only run the test file if the server is in the specified
    52  //       security mode. (The default is `config secure insecure` i.e.
    53  //       the test file is applicable to both.)
    54  //
    55  // set_hba
    56  // <hba config>
    57  //       Load the provided HBA configuration via the cluster setting
    58  //       server.host_based_authentication.configuration.
    59  //       The expected output is the configuration after parsing
    60  //       and reloading in the server.
    61  //
    62  // sql
    63  // <sql input>
    64  //       Execute the specified SQL statement using the default root
    65  //       connection provided by StartServer().
    66  //
    67  // authlog N
    68  // <regexp>
    69  //       Expect <regexp> at the end of the auth log then report the
    70  //       N entries before that.
    71  //
    72  // connect [key=value ...]
    73  //       Attempt a SQL connection using the provided connection
    74  //       parameters using the pg "DSN notation": k/v pairs separated
    75  //       by spaces.
    76  //       The following standard pg keys are recognized:
    77  //            user - the username
    78  //            password - the password
    79  //            host - the server name/address
    80  //            port - the server port
    81  //            sslmode, sslrootcert, sslcert, sslkey - SSL parameters.
    82  //
    83  //       The order of k/v pairs matters: if the same key is specified
    84  //       multiple times, the first occurrence takes priority.
    85  //
    86  //       Additionally, the test runner will always _append_ a default
    87  //       value for user (root), host/port/sslrootcert from the
    88  //       initialized test server. This default configuration is placed
    89  //       at the end so that each test can override the values.
    90  //
    91  //       The test runner also adds a default value for sslcert and
    92  //       sslkey based on the value of "user" — either when provided by
    93  //       the test, or root by default.
    94  //
    95  //       When the user is either "root" or "testuser" (those are the
    96  //       users for which the test server generates certificates),
    97  //       sslmode also gets a default of "verify-full". For other
    98  //       users, sslmode is initialized by default to "verify-ca".
    99  //
   100  // For the directives "sql" and "connect", the expected output can be
   101  // either "ok" (no error) or "ERROR:" followed by the expected error
   102  // string.
   103  // The auth and connection log entries, if any, are also produced
   104  // alongside the "ok" or "ERROR" message.
   105  //
   106  func TestAuthenticationAndHBARules(t *testing.T) {
   107  	defer leaktest.AfterTest(t)()
   108  
   109  	testutils.RunTrueAndFalse(t, "insecure", func(t *testing.T, insecure bool) {
   110  		hbaRunTest(t, insecure)
   111  	})
   112  }
   113  
   114  const socketConnVirtualPort = "6"
   115  
   116  func makeSocketFile(t *testing.T) (socketDir, socketFile string, cleanupFn func()) {
   117  	if runtime.GOOS == "windows" {
   118  		// Unix sockets not supported on windows.
   119  		return "", "", func() {}
   120  	}
   121  	// We need a temp directory in which we'll create the unix socket.
   122  	//
   123  	// On BSD, binding to a socket is limited to a path length of 104 characters
   124  	// (including the NUL terminator). In glibc, this limit is 108 characters.
   125  	//
   126  	// macOS has a tendency to produce very long temporary directory names, so
   127  	// we are careful to keep all the constants involved short.
   128  	tempDir, err := ioutil.TempDir("", "TestAuth")
   129  	if err != nil {
   130  		t.Fatal(err)
   131  	}
   132  	// ".s.PGSQL.NNNN" is the standard unix socket name supported by pg clients.
   133  	return tempDir,
   134  		filepath.Join(tempDir, ".s.PGSQL."+socketConnVirtualPort),
   135  		func() { _ = os.RemoveAll(tempDir) }
   136  }
   137  
   138  func hbaRunTest(t *testing.T, insecure bool) {
   139  	httpScheme := "http://"
   140  	if !insecure {
   141  		httpScheme = "https://"
   142  	}
   143  	datadriven.Walk(t, "testdata/auth", func(t *testing.T, path string) {
   144  		maybeSocketDir, maybeSocketFile, cleanup := makeSocketFile(t)
   145  		defer cleanup()
   146  
   147  		s, conn, _ := serverutils.StartServer(t,
   148  			base.TestServerArgs{Insecure: insecure, SocketFile: maybeSocketFile})
   149  		defer s.Stopper().Stop(context.Background())
   150  
   151  		// Enable conn/auth logging.
   152  		// We can't use the cluster settings to do this, because
   153  		// cluster settings propagate asynchronously.
   154  		s.(*server.TestServer).PGServer().TestingEnableConnAuthLogging()
   155  
   156  		// We really need to have the logs go to files, so that -show-logs
   157  		// does not break the "authlog" directives. We also must call
   158  		// this here and not earlier, because it needs to enforce the
   159  		// redirect on the secondary loggers created by StartServer().
   160  		defer log.ScopeWithoutShowLogs(t).Close(t)
   161  
   162  		pgServer := s.(*server.TestServer).PGServer()
   163  
   164  		httpClient, err := s.GetHTTPClient()
   165  		if err != nil {
   166  			t.Fatal(err)
   167  		}
   168  		httpHBAUrl := httpScheme + s.HTTPAddr() + "/debug/hba_conf"
   169  
   170  		if _, err := conn.ExecContext(context.Background(), `CREATE USER $1`, server.TestUser); err != nil {
   171  			t.Fatal(err)
   172  		}
   173  
   174  		datadriven.RunTest(t, path, func(t *testing.T, td *datadriven.TestData) string {
   175  			resultString, err := func() (string, error) {
   176  				switch td.Cmd {
   177  				case "config":
   178  					allowed := false
   179  					for _, a := range td.CmdArgs {
   180  						switch a.Key {
   181  						case "secure":
   182  							allowed = allowed || !insecure
   183  						case "insecure":
   184  							allowed = allowed || insecure
   185  						default:
   186  							t.Fatalf("unknown configuration: %s", a.Key)
   187  						}
   188  					}
   189  					if !allowed {
   190  						t.Skip("Test file not applicable at this security level.")
   191  					}
   192  
   193  				case "set_hba":
   194  					_, err := conn.ExecContext(context.Background(),
   195  						`SET CLUSTER SETTING server.host_based_authentication.configuration = $1`, td.Input)
   196  					if err != nil {
   197  						return "", err
   198  					}
   199  
   200  					// Wait until the configuration has propagated back to the
   201  					// test client. We need to wait because the cluster setting
   202  					// change propagates asynchronously.
   203  					expConf := pgwire.DefaultHBAConfig
   204  					if td.Input != "" {
   205  						expConf, err = pgwire.ParseAndNormalize(td.Input)
   206  						if err != nil {
   207  							// The SET above succeeded so we don't expect a problem here.
   208  							t.Fatal(err)
   209  						}
   210  					}
   211  					testutils.SucceedsSoon(t, func() error {
   212  						curConf := pgServer.GetAuthenticationConfiguration()
   213  						if expConf.String() != curConf.String() {
   214  							return errors.Newf(
   215  								"HBA config not yet loaded\ngot:\n%s\nexpected:\n%s",
   216  								curConf, expConf)
   217  						}
   218  						return nil
   219  					})
   220  
   221  					// Verify the HBA configuration was processed properly by
   222  					// reporting the resulting cached configuration.
   223  					resp, err := httpClient.Get(httpHBAUrl)
   224  					if err != nil {
   225  						return "", err
   226  					}
   227  					defer resp.Body.Close()
   228  					body, err := ioutil.ReadAll(resp.Body)
   229  					if err != nil {
   230  						return "", err
   231  					}
   232  					return string(body), nil
   233  
   234  				case "sql":
   235  					_, err := conn.ExecContext(context.Background(), td.Input)
   236  					return "ok", err
   237  
   238  				case "authlog":
   239  					if len(td.CmdArgs) < 0 {
   240  						t.Fatal("not enough arguments")
   241  					}
   242  					numEntries, err := strconv.Atoi(td.CmdArgs[0].Key)
   243  					if err != nil {
   244  						t.Fatal(err)
   245  					}
   246  					re, err := regexp.Compile(td.Input)
   247  					if err != nil {
   248  						t.Fatal(err)
   249  					}
   250  
   251  					var buf strings.Builder
   252  					if err := testutils.SucceedsSoonError(func() error {
   253  						buf.Reset()
   254  						t.Logf("attempting to scan logs...")
   255  
   256  						// Note: even though FetchEntriesFromFiles advertises a mechanism
   257  						// to filter entries by timestamp or just retrieve the last N entries,
   258  						// this is currently broken for secondary loggers.
   259  						// See: https://github.com/cockroachdb/cockroach/issues/45745
   260  						// So instead we need to do the filtering ourselves.
   261  						entries, err := log.FetchEntriesFromFiles(0, math.MaxInt64, 10000, authLogFileRe)
   262  						if err != nil {
   263  							t.Fatal(err)
   264  						}
   265  						if len(entries) == 0 {
   266  							return errors.New("no log entries")
   267  						} else {
   268  							// Note: entries are delivered by Fetch in reverse order.
   269  							i := numEntries - 1
   270  							if i < 0 || i >= len(entries) {
   271  								i = len(entries) - 1
   272  							}
   273  							for ; i >= 0; i-- {
   274  								entry := &entries[i]
   275  								t.Logf("found log entry: %+v", *entry)
   276  
   277  								// The message is going to contain a client address, with a random port number.
   278  								// To make the test deterministic, erase the random part.
   279  								msg := addrRe.ReplaceAllString(entry.Message, ",client=XXX")
   280  								// Ditto with the duration.
   281  								msg = durationRe.ReplaceAllString(msg, "duration: XXX")
   282  
   283  								fmt.Fprintf(&buf, "%c: %s\n", entry.Severity.String()[0], msg)
   284  							}
   285  							lastLogMsg := entries[0].Message
   286  							if !re.MatchString(lastLogMsg) {
   287  								return errors.Newf("last entry does not match: %q", lastLogMsg)
   288  							}
   289  						}
   290  						return nil
   291  					}); err != nil {
   292  						buf.WriteString("ERROR: unable to find log line matching regexp")
   293  					}
   294  					return buf.String(), nil
   295  
   296  				case "connect", "connect_unix":
   297  					if td.Cmd == "connect_unix" && runtime.GOOS == "windows" {
   298  						// Unix sockets not supported; assume the test succeeded.
   299  						return td.Expected, nil
   300  					}
   301  
   302  					// Prepare a connection string using the server's default.
   303  					// What is the user requested by the test?
   304  					user := security.RootUser
   305  					if td.HasArg("user") {
   306  						td.ScanArgs(t, "user", &user)
   307  					}
   308  
   309  					// We want the certs to be present in the filesystem for this test.
   310  					// However, certs are only generated for users "root" and "testuser" specifically.
   311  					sqlURL, cleanupFn := sqlutils.PGUrlWithOptionalClientCerts(
   312  						t, s.ServingSQLAddr(), t.Name(), url.User(user),
   313  						user == security.RootUser || user == server.TestUser /* withClientCerts */)
   314  					defer cleanupFn()
   315  
   316  					var host, port string
   317  					if td.Cmd == "connect" {
   318  						host, port, err = net.SplitHostPort(s.ServingSQLAddr())
   319  						if err != nil {
   320  							t.Fatal(err)
   321  						}
   322  					} else /* unix */ {
   323  						host = maybeSocketDir
   324  						port = socketConnVirtualPort
   325  					}
   326  					options, err := url.ParseQuery(sqlURL.RawQuery)
   327  					if err != nil {
   328  						t.Fatal(err)
   329  					}
   330  
   331  					// Here we make use of the fact that pq accepts connection
   332  					// strings using the alternate postgres configuration format,
   333  					// consisting of k=v pairs separated by spaces.
   334  					// For example this is a valid connection string:
   335  					//  "host=localhost port=5432 user=root"
   336  					// We also make use of the datadriven K/V parsing facility,
   337  					// which always prioritizes the first K instance in the test's
   338  					// argument list. We append the server's config parameters
   339  					// at the end, letting the test override by introducing its
   340  					// own values at the beginning.
   341  					args := append(td.CmdArgs,
   342  						datadriven.CmdArg{Key: "user", Vals: []string{user}},
   343  						datadriven.CmdArg{Key: "host", Vals: []string{host}},
   344  						datadriven.CmdArg{Key: "port", Vals: []string{port}},
   345  					)
   346  					for key := range options {
   347  						args = append(args,
   348  							datadriven.CmdArg{Key: key, Vals: []string{options.Get(key)}})
   349  					}
   350  					// Now turn the cmdargs into a dsn.
   351  					var dsnBuf strings.Builder
   352  					sp := ""
   353  					seenKeys := map[string]struct{}{}
   354  					for _, a := range args {
   355  						if _, ok := seenKeys[a.Key]; ok {
   356  							continue
   357  						}
   358  						seenKeys[a.Key] = struct{}{}
   359  						val := ""
   360  						if len(a.Vals) > 0 {
   361  							val = a.Vals[0]
   362  						}
   363  						fmt.Fprintf(&dsnBuf, "%s%s=%s", sp, a.Key, val)
   364  						sp = " "
   365  					}
   366  					dsn := dsnBuf.String()
   367  
   368  					// Finally, connect and test the connection.
   369  					dbSQL, err := gosql.Open("postgres", dsn)
   370  					if dbSQL != nil {
   371  						// Note: gosql.Open may return a valid db (with an open
   372  						// TCP connection) even if there is an error. We want to
   373  						// ensure this gets closed so that we catch the conn close
   374  						// message in the log.
   375  						defer dbSQL.Close()
   376  					}
   377  					if err != nil {
   378  						return "", err
   379  					}
   380  					row := dbSQL.QueryRow("SELECT current_catalog")
   381  					var dbName string
   382  					if err := row.Scan(&dbName); err != nil {
   383  						return "", err
   384  					}
   385  					return "ok " + dbName, nil
   386  
   387  				default:
   388  					td.Fatalf(t, "unknown command: %s", td.Cmd)
   389  				}
   390  				return "", nil
   391  			}()
   392  			if err != nil {
   393  				return fmtErr(err)
   394  			}
   395  			return resultString
   396  		})
   397  	})
   398  }
   399  
   400  var authLogFileRe = regexp.MustCompile(`pgwire/(auth|conn|server)\.go`)
   401  var addrRe = regexp.MustCompile(`,client(=[^\],]*)?`)
   402  var durationRe = regexp.MustCompile(`duration: \d.*s`)
   403  
   404  // fmtErr formats an error into an expected output.
   405  func fmtErr(err error) string {
   406  	if err != nil {
   407  		errStr := ""
   408  		if pqErr := (*pq.Error)(nil); errors.As(err, &pqErr) {
   409  			errStr = pqErr.Message
   410  			if pqErr.Code != pgcode.Uncategorized {
   411  				errStr += fmt.Sprintf(" (SQLSTATE %s)", pqErr.Code)
   412  			}
   413  			if pqErr.Hint != "" {
   414  				hint := strings.Replace(pqErr.Hint, stdstrings.IssueReferral, "<STANDARD REFERRAL>", 1)
   415  				errStr += "\nHINT: " + hint
   416  			}
   417  			if pqErr.Detail != "" {
   418  				errStr += "\nDETAIL: " + pqErr.Detail
   419  			}
   420  		} else {
   421  			errStr = err.Error()
   422  		}
   423  		return "ERROR: " + errStr
   424  	}
   425  	return "ok"
   426  }