github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/pgwire/hba/parser.go (about)

     1  // Copyright 2020 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package hba
    12  
    13  import (
    14  	"fmt"
    15  	"net"
    16  	"strings"
    17  
    18  	"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode"
    19  	"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror"
    20  	"github.com/cockroachdb/cockroach/pkg/util/errorutil/unimplemented"
    21  	"github.com/cockroachdb/errors"
    22  )
    23  
    24  // scannedInput represents the result of tokenizing the input
    25  // configuration data.
    26  //
    27  // Inspired from pg's source, file src/backend/libpq/hba.c,
    28  // function tokenize_file.
    29  //
    30  // The scanner tokenizes the input and stores the resulting data into
    31  // three lists: a list of lines, a list of line numbers, and a list of
    32  // raw line contents.
    33  type scannedInput struct {
    34  	// The list of lines is a triple-nested list structure.  Each line is a list of
    35  	// fields, and each field is a List of tokens.
    36  	lines   []hbaLine
    37  	linenos []int
    38  }
    39  
    40  type hbaLine struct {
    41  	input  string
    42  	tokens [][]String
    43  }
    44  
    45  // Parse parses the provided HBA configuration.
    46  func Parse(input string) (*Conf, error) {
    47  	tokens, err := tokenize(input)
    48  	if err != nil {
    49  		return nil, err
    50  	}
    51  
    52  	var entries []Entry
    53  	for i, line := range tokens.lines {
    54  		entry, err := parseHbaLine(line)
    55  		if err != nil {
    56  			return nil, errors.Wrapf(
    57  				pgerror.WithCandidateCode(err, pgcode.ConfigFile),
    58  				"line %d", tokens.linenos[i])
    59  		}
    60  		entries = append(entries, entry)
    61  	}
    62  
    63  	return &Conf{Entries: entries}, nil
    64  }
    65  
    66  // parseHbaLine parses one line of HBA configuration.
    67  //
    68  // Inspired from pg's src/backend/libpq/hba.c, parse_hba_line().
    69  func parseHbaLine(inputLine hbaLine) (entry Entry, err error) {
    70  	fieldIdx := 0
    71  
    72  	entry.Input = inputLine.input
    73  	line := inputLine.tokens
    74  	// Read the connection type.
    75  	if len(line[fieldIdx]) > 1 {
    76  		return entry, errors.WithHint(
    77  			errors.New("multiple values specified for connection type"),
    78  			"Specify exactly one connection type per line.")
    79  	}
    80  	entry.ConnType, err = ParseConnType(line[fieldIdx][0].Value)
    81  	if err != nil {
    82  		return entry, err
    83  	}
    84  
    85  	// Get the databases.
    86  	fieldIdx++
    87  	if fieldIdx >= len(line) {
    88  		return entry, errors.New("end-of-line before database specification")
    89  	}
    90  	entry.Database = line[fieldIdx]
    91  
    92  	// Get the roles.
    93  	fieldIdx++
    94  	if fieldIdx >= len(line) {
    95  		return entry, errors.New("end-of-line before role specification")
    96  	}
    97  	entry.User = line[fieldIdx]
    98  
    99  	if entry.ConnType != ConnLocal {
   100  		fieldIdx++
   101  		if fieldIdx >= len(line) {
   102  			return entry, errors.New("end-of-line before IP address specification")
   103  		}
   104  		tokens := line[fieldIdx]
   105  		if len(tokens) > 1 {
   106  			return entry, errors.WithHint(
   107  				errors.New("multiple values specified for host address"),
   108  				"Specify one address range per line.")
   109  		}
   110  		token := tokens[0]
   111  		switch {
   112  		case token.Value == "":
   113  			return entry, errors.New("cannot use empty string as address")
   114  		case token.IsKeyword("all"):
   115  			entry.Address = token
   116  		case token.IsKeyword("samehost"), token.IsKeyword("samenet"):
   117  			return entry, unimplemented.Newf(
   118  				fmt.Sprintf("hba-net-%s", token.Value),
   119  				"address specification %s is not yet supported", errors.Safe(token.Value))
   120  		default:
   121  			// Split name/mask.
   122  			addr := token.Value
   123  			if strings.Contains(addr, "/") {
   124  				_, ipnet, err := net.ParseCIDR(addr)
   125  				if err != nil {
   126  					return entry, err
   127  				}
   128  				entry.Address = ipnet
   129  			} else {
   130  				var ip net.IP
   131  				hostname := addr
   132  				if ip = net.ParseIP(addr); ip != nil {
   133  					hostname = ""
   134  				}
   135  				if hostname != "" {
   136  					entry.Address = String{Value: addr, Quoted: token.Quoted}
   137  				} else {
   138  					// First field was an IP address.
   139  					fieldIdx++
   140  					if fieldIdx >= len(line) {
   141  						return entry, errors.WithHint(
   142  							errors.New("end-of-line before netmask specification"),
   143  							"Specify an address range in CIDR notation, or provide a separate netmask.")
   144  					}
   145  					if len(line[fieldIdx]) > 1 {
   146  						return entry, errors.New("multiple values specified for netmask")
   147  					}
   148  					maybeMask := net.ParseIP(line[fieldIdx][0].Value)
   149  					if err := checkMask(maybeMask); err != nil {
   150  						return entry, errors.Wrapf(err, "invalid IP mask \"%s\"", line[fieldIdx][0].Value)
   151  					}
   152  					// Do the address families match?
   153  					if (maybeMask.To4() == nil) != (ip.To4() == nil) {
   154  						return entry, errors.Newf("IP address and mask do not match")
   155  					}
   156  					mask := net.IPMask(maybeMask)
   157  					entry.Address = &net.IPNet{IP: ip.Mask(mask), Mask: mask}
   158  				}
   159  			}
   160  		}
   161  	} /* entryType != local */
   162  
   163  	// Get the authentication method.
   164  	fieldIdx++
   165  	if fieldIdx >= len(line) {
   166  		return entry, errors.New("end-of-line before authentication method")
   167  	}
   168  	if len(line[fieldIdx]) > 1 {
   169  		return entry, errors.WithHint(
   170  			errors.New("multiple values specified for authentication method"),
   171  			"Specify exactly one authentication method per line.")
   172  	}
   173  	entry.Method = line[fieldIdx][0]
   174  	if entry.Method.Value == "" {
   175  		return entry, errors.New("cannot use empty string as authentication method")
   176  	}
   177  
   178  	// Parse remaining arguments.
   179  	for fieldIdx++; fieldIdx < len(line); fieldIdx++ {
   180  		for _, tok := range line[fieldIdx] {
   181  			kv := strings.SplitN(tok.Value, "=", 2)
   182  			if len(kv) != 2 {
   183  				return entry, errors.Newf("authentication option not in name=value format: %s", tok.Value)
   184  			}
   185  			entry.Options = append(entry.Options, [2]string{kv[0], kv[1]})
   186  			entry.OptionQuotes = append(entry.OptionQuotes, tok.Quoted)
   187  		}
   188  	}
   189  
   190  	return entry, nil
   191  }
   192  
   193  // checkMask verifies that maybeMask is a valid IP mask, that is,
   194  // the value is all ones followed by all zeroes.
   195  func checkMask(maybeMask net.IP) error {
   196  	if maybeMask == nil {
   197  		return errors.New("netmask not in IP numeric format")
   198  	}
   199  	if ip4 := maybeMask.To4(); ip4 != nil {
   200  		maybeMask = ip4
   201  	}
   202  	i := 0
   203  	// Skip over all leading ones.
   204  	for ; i < len(maybeMask) && maybeMask[i] == '\xff'; i++ {
   205  	}
   206  	// Skip over the middle mixed ones/zeroes, if any.
   207  	if i < len(maybeMask) {
   208  		switch maybeMask[i] {
   209  		case 0xff, 0xfe, 0xfc, 0xf8, 0xf0, 0xe0, 0xc0, 0x80:
   210  			i++
   211  		}
   212  	}
   213  	// Skip over all trailing zeroes.
   214  	for ; i < len(maybeMask) && maybeMask[i] == '\x00'; i++ {
   215  	}
   216  	// If there's anything remaining, we don't have a proper mask.
   217  	if i < len(maybeMask) {
   218  		return errors.New("address is not a mask")
   219  	}
   220  	return nil
   221  }
   222  
   223  // ParseConnType parses the connection type field.
   224  func ParseConnType(s string) (ConnType, error) {
   225  	switch s {
   226  	case "local":
   227  		return ConnLocal, nil
   228  	case "host":
   229  		return ConnHostAny, nil
   230  	case "hostssl":
   231  		return ConnHostSSL, nil
   232  	case "hostnossl":
   233  		return ConnHostNoSSL, nil
   234  	}
   235  	return 0, errors.Newf("unknown connection type: %q", s)
   236  }