github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/lex/encode.go (about)

     1  // Copyright 2012, Google Inc. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in licenses/BSD-vitess.txt.
     4  
     5  // Portions of this file are additionally subject to the following
     6  // license and copyright.
     7  //
     8  // Copyright 2015 The Cockroach Authors.
     9  //
    10  // Use of this software is governed by the Business Source License
    11  // included in the file licenses/BSL.txt.
    12  //
    13  // As of the Change Date specified in that file, in accordance with
    14  // the Business Source License, use of this software will be governed
    15  // by the Apache License, Version 2.0, included in the file
    16  // licenses/APL.txt.
    17  
    18  // This code was derived from https://github.com/youtube/vitess.
    19  
    20  package lex
    21  
    22  import (
    23  	"bytes"
    24  	"encoding/base64"
    25  	"encoding/hex"
    26  	"fmt"
    27  	"strings"
    28  	"unicode/utf8"
    29  
    30  	"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode"
    31  	"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror"
    32  	"github.com/cockroachdb/cockroach/pkg/util/stringencoding"
    33  	"github.com/cockroachdb/errors"
    34  )
    35  
// mustQuoteMap is the set of printable ASCII bytes whose presence in a
// string forces EncodeSQLStringWithFlags to emit surrounding quotes even
// when EncBareStrings was requested.
var mustQuoteMap = map[byte]bool{
	' ': true,
	',': true,
	'{': true,
	'}': true,
}
    42  
    43  // EncodeFlags influence the formatting of strings and identifiers.
    44  type EncodeFlags int
    45  
    46  // HasFlags tests whether the given flags are set.
    47  func (f EncodeFlags) HasFlags(subset EncodeFlags) bool {
    48  	return f&subset == subset
    49  }
    50  
const (
	// EncNoFlags indicates nothing special should happen while encoding.
	EncNoFlags EncodeFlags = 0

	// EncBareStrings indicates that strings will be rendered without
	// wrapping quotes if they contain no special characters.
	//
	// Note: iota is already 1 on this line because EncNoFlags occupies
	// index 0 of the block, so this flag is 1<<1 == 2 and bit 0 is unused.
	EncBareStrings EncodeFlags = 1 << iota

	// EncBareIdentifiers indicates that identifiers will be rendered
	// without wrapping quotes. Its value is 1<<2 == 4.
	EncBareIdentifiers

	// EncFirstFreeFlagBit needs to remain unused; it is used as base
	// bit offset for tree.FmtFlags. Its value is 1<<3 == 8.
	EncFirstFreeFlagBit
)
    67  
// EncodeSQLString writes in to buf as a SQL string literal. All unicode
// and non-printable characters are escaped. It is shorthand for
// EncodeSQLStringWithFlags called with EncNoFlags, so the output is
// always wrapped in quotes.
func EncodeSQLString(buf *bytes.Buffer, in string) {
	EncodeSQLStringWithFlags(buf, in, EncNoFlags)
}
    73  
    74  // EscapeSQLString returns an escaped SQL representation of the given
    75  // string. This is suitable for safely producing a SQL string valid
    76  // for input to the parser.
    77  func EscapeSQLString(in string) string {
    78  	var buf bytes.Buffer
    79  	EncodeSQLString(&buf, in)
    80  	return buf.String()
    81  }
    82  
// EncodeSQLStringWithFlags writes a string literal to buf. All
// unicode and non-printable characters are escaped. flags controls
// the output format: if EncBareStrings is set, the output string
// will not be wrapped in quotes if the string contains no special
// characters.
//
// Three output shapes are possible:
//   - e'...' as soon as at least one character had to be escaped;
//   - '...'  when nothing was escaped and a bare string is not allowed;
//   - the input verbatim when nothing was escaped, EncBareStrings is set
//     and no byte listed in mustQuoteMap was seen. In that mode an empty
//     input produces no output at all.
func EncodeSQLStringWithFlags(buf *bytes.Buffer, in string, flags EncodeFlags) {
	// See http://www.postgresql.org/docs/9.4/static/sql-syntax-lexical.html
	// start is the offset of the first input byte not yet copied to buf.
	start := 0
	escapedString := false
	bareStrings := flags.HasFlags(EncBareStrings)
	// Loop through each unicode code point.
	for i, r := range in {
		// Skip offsets that a previous escape already accounted for.
		if i < start {
			continue
		}
		// ch is only meaningful when r is in the ASCII range; for larger
		// code points the conversion truncates.
		ch := byte(r)
		if r >= 0x20 && r < 0x7F {
			if mustQuoteMap[ch] {
				// We have to quote this string - ignore bareStrings setting
				bareStrings = false
			}
			if !stringencoding.NeedEscape(ch) && ch != '\'' {
				// Printable ASCII that can be copied as-is; it is flushed
				// later as part of in[start:i] or of the final tail.
				continue
			}
		}

		// From here on r must be escaped: it is either outside printable
		// ASCII, a single quote, or flagged by stringencoding.NeedEscape.
		if !escapedString {
			buf.WriteString("e'") // begin e'xxx' string
			escapedString = true
		}
		// Flush the unescaped run that precedes r.
		buf.WriteString(in[start:i])
		ln := utf8.RuneLen(r)
		if ln < 0 {
			// r is not encodable as UTF-8: step over a single byte.
			start = i + 1
		} else {
			start = i + ln
		}
		// NOTE(review): the escaped form is produced by the stringencoding
		// package; the number of input bytes it consumes is assumed to
		// agree with the start computation above - confirm there.
		stringencoding.EncodeEscapedChar(buf, in, r, ch, i, '\'')
	}

	// An opening quote is still owed only if no e' prefix was written and
	// a bare string is not permitted.
	quote := !escapedString && !bareStrings
	if quote {
		buf.WriteByte('\'') // begin 'xxx' string if nothing was escaped
	}
	// Copy whatever remains after the last escaped character.
	if start < len(in) {
		buf.WriteString(in[start:])
	}
	// Close the literal for both the e'...' and the '...' forms.
	if escapedString || quote {
		buf.WriteByte('\'')
	}
}
   134  
   135  // EncodeUnrestrictedSQLIdent writes the identifier in s to buf.
   136  // The identifier is only quoted if the flags don't tell otherwise and
   137  // the identifier contains special characters.
   138  func EncodeUnrestrictedSQLIdent(buf *bytes.Buffer, s string, flags EncodeFlags) {
   139  	if flags.HasFlags(EncBareIdentifiers) || isBareIdentifier(s) {
   140  		buf.WriteString(s)
   141  		return
   142  	}
   143  	EncodeEscapedSQLIdent(buf, s)
   144  }
   145  
   146  // EncodeRestrictedSQLIdent writes the identifier in s to buf. The
   147  // identifier is quoted if either the flags ask for it, the identifier
   148  // contains special characters, or the identifier is a reserved SQL
   149  // keyword.
   150  func EncodeRestrictedSQLIdent(buf *bytes.Buffer, s string, flags EncodeFlags) {
   151  	if flags.HasFlags(EncBareIdentifiers) || (!isReservedKeyword(s) && isBareIdentifier(s)) {
   152  		buf.WriteString(s)
   153  		return
   154  	}
   155  	EncodeEscapedSQLIdent(buf, s)
   156  }
   157  
   158  // EncodeEscapedSQLIdent writes the identifier in s to buf. The
   159  // identifier is always quoted. Double quotes inside the identifier
   160  // are escaped.
   161  func EncodeEscapedSQLIdent(buf *bytes.Buffer, s string) {
   162  	buf.WriteByte('"')
   163  	start := 0
   164  	for i, n := 0, len(s); i < n; i++ {
   165  		ch := s[i]
   166  		// The only character that requires escaping is a double quote.
   167  		if ch == '"' {
   168  			if start != i {
   169  				buf.WriteString(s[start:i])
   170  			}
   171  			start = i + 1
   172  			buf.WriteByte(ch)
   173  			buf.WriteByte(ch) // add extra copy of ch
   174  		}
   175  	}
   176  	if start < len(s) {
   177  		buf.WriteString(s[start:])
   178  	}
   179  	buf.WriteByte('"')
   180  }
   181  
   182  // EncodeLocaleName writes the locale identifier in s to buf. Any dash
   183  // characters are mapped to underscore characters. Underscore characters do not
   184  // need to be quoted, and they are considered equivalent to dash characters by
   185  // the CLDR standard: http://cldr.unicode.org/.
   186  func EncodeLocaleName(buf *bytes.Buffer, s string) {
   187  	for i, n := 0, len(s); i < n; i++ {
   188  		ch := s[i]
   189  		if ch == '-' {
   190  			buf.WriteByte('_')
   191  		} else {
   192  			buf.WriteByte(ch)
   193  		}
   194  	}
   195  }
   196  
   197  // EncodeSQLBytes encodes the SQL byte array in 'in' to buf, to a
   198  // format suitable for re-scanning. We don't use a straightforward hex
   199  // encoding here with x'...'  because the result would be less
   200  // compact. We are trading a little more time during the encoding to
   201  // have a little less bytes on the wire.
   202  func EncodeSQLBytes(buf *bytes.Buffer, in string) {
   203  	start := 0
   204  	buf.WriteString("b'")
   205  	// Loop over the bytes of the string (i.e., don't use range over unicode
   206  	// code points).
   207  	for i, n := 0, len(in); i < n; i++ {
   208  		ch := in[i]
   209  		if encodedChar := stringencoding.EncodeMap[ch]; encodedChar != stringencoding.DontEscape {
   210  			buf.WriteString(in[start:i])
   211  			buf.WriteByte('\\')
   212  			buf.WriteByte(encodedChar)
   213  			start = i + 1
   214  		} else if ch == '\'' {
   215  			// We can't just fold this into stringencoding.EncodeMap because
   216  			// stringencoding.EncodeMap is also used for strings which
   217  			// aren't quoted with single-quotes
   218  			buf.WriteString(in[start:i])
   219  			buf.WriteByte('\\')
   220  			buf.WriteByte(ch)
   221  			start = i + 1
   222  		} else if ch < 0x20 || ch >= 0x7F {
   223  			buf.WriteString(in[start:i])
   224  			// Escape non-printable characters.
   225  			buf.Write(stringencoding.HexMap[ch])
   226  			start = i + 1
   227  		}
   228  	}
   229  	buf.WriteString(in[start:])
   230  	buf.WriteByte('\'')
   231  }
   232  
   233  // EncodeByteArrayToRawBytes converts a SQL-level byte array into raw
   234  // bytes according to the encoding specification in "be".
   235  // If the skipHexPrefix argument is set, the hexadecimal encoding does not
   236  // prefix the output with "\x". This is suitable e.g. for the encode()
   237  // built-in.
   238  func EncodeByteArrayToRawBytes(data string, be BytesEncodeFormat, skipHexPrefix bool) string {
   239  	switch be {
   240  	case BytesEncodeHex:
   241  		head := 2
   242  		if skipHexPrefix {
   243  			head = 0
   244  		}
   245  		res := make([]byte, head+hex.EncodedLen(len(data)))
   246  		if !skipHexPrefix {
   247  			res[0] = '\\'
   248  			res[1] = 'x'
   249  		}
   250  		hex.Encode(res[head:], []byte(data))
   251  		return string(res)
   252  
   253  	case BytesEncodeEscape:
   254  		// PostgreSQL does not allow all the escapes formats recognized by
   255  		// CockroachDB's scanner. It only recognizes octal and \\ for the
   256  		// backslash itself.
   257  		// See https://www.postgresql.org/docs/current/static/datatype-binary.html#AEN5667
   258  		res := make([]byte, 0, len(data))
   259  		for _, c := range []byte(data) {
   260  			if c == '\\' {
   261  				res = append(res, '\\', '\\')
   262  			} else if c < 32 || c >= 127 {
   263  				// Escape the character in octal.
   264  				//
   265  				// Note: CockroachDB only supports UTF-8 for which all values
   266  				// below 128 are ASCII. There is no locale-dependent escaping
   267  				// in that case.
   268  				res = append(res, '\\', '0'+(c>>6), '0'+((c>>3)&7), '0'+(c&7))
   269  			} else {
   270  				res = append(res, c)
   271  			}
   272  		}
   273  		return string(res)
   274  
   275  	case BytesEncodeBase64:
   276  		return base64.StdEncoding.EncodeToString([]byte(data))
   277  
   278  	default:
   279  		panic(fmt.Sprintf("unhandled format: %s", be))
   280  	}
   281  }
   282  
   283  // DecodeRawBytesToByteArray converts raw bytes to a SQL-level byte array
   284  // according to the encoding specification in "be".
   285  // When using the Hex format, the caller is responsible for skipping the
   286  // "\x" prefix, if any. See DecodeRawBytesToByteArrayAuto() below for
   287  // an alternative.
   288  func DecodeRawBytesToByteArray(data string, be BytesEncodeFormat) ([]byte, error) {
   289  	switch be {
   290  	case BytesEncodeHex:
   291  		return hex.DecodeString(data)
   292  
   293  	case BytesEncodeEscape:
   294  		// PostgreSQL does not allow all the escapes formats recognized by
   295  		// CockroachDB's scanner. It only recognizes octal and \\ for the
   296  		// backslash itself.
   297  		// See https://www.postgresql.org/docs/current/static/datatype-binary.html#AEN5667
   298  		res := make([]byte, 0, len(data))
   299  		for i := 0; i < len(data); i++ {
   300  			ch := data[i]
   301  			if ch != '\\' {
   302  				res = append(res, ch)
   303  				continue
   304  			}
   305  			if i >= len(data)-1 {
   306  				return nil, pgerror.New(pgcode.InvalidEscapeSequence,
   307  					"bytea encoded value ends with escape character")
   308  			}
   309  			if data[i+1] == '\\' {
   310  				res = append(res, '\\')
   311  				i++
   312  				continue
   313  			}
   314  			if i+3 >= len(data) {
   315  				return nil, pgerror.New(pgcode.InvalidEscapeSequence,
   316  					"bytea encoded value ends with incomplete escape sequence")
   317  			}
   318  			b := byte(0)
   319  			for j := 1; j <= 3; j++ {
   320  				octDigit := data[i+j]
   321  				if octDigit < '0' || octDigit > '7' || (j == 1 && octDigit > '3') {
   322  					return nil, pgerror.New(pgcode.InvalidEscapeSequence,
   323  						"invalid bytea escape sequence")
   324  				}
   325  				b = (b << 3) | (octDigit - '0')
   326  			}
   327  			res = append(res, b)
   328  			i += 3
   329  		}
   330  		return res, nil
   331  
   332  	case BytesEncodeBase64:
   333  		return base64.StdEncoding.DecodeString(data)
   334  
   335  	default:
   336  		return nil, errors.AssertionFailedf("unhandled format: %s", be)
   337  	}
   338  }
   339  
   340  // DecodeRawBytesToByteArrayAuto detects which format to use with
   341  // DecodeRawBytesToByteArray(). It only supports hex ("\x" prefix)
   342  // and escape.
   343  func DecodeRawBytesToByteArrayAuto(data []byte) ([]byte, error) {
   344  	if len(data) >= 2 && data[0] == '\\' && (data[1] == 'x' || data[1] == 'X') {
   345  		return DecodeRawBytesToByteArray(string(data[2:]), BytesEncodeHex)
   346  	}
   347  	return DecodeRawBytesToByteArray(string(data), BytesEncodeEscape)
   348  }
   349  
   350  // BytesEncodeFormat controls which format to use for BYTES->STRING
   351  // conversions.
   352  type BytesEncodeFormat int
   353  
   354  const (
   355  	// BytesEncodeHex uses the hex format: e'abc\n'::BYTES::STRING -> '\x61626312'.
   356  	// This is the default, for compatibility with PostgreSQL.
   357  	BytesEncodeHex BytesEncodeFormat = iota
   358  	// BytesEncodeEscape uses the escaped format: e'abc\n'::BYTES::STRING -> 'abc\012'.
   359  	BytesEncodeEscape
   360  	// BytesEncodeBase64 uses base64 encoding.
   361  	BytesEncodeBase64
   362  )
   363  
   364  func (f BytesEncodeFormat) String() string {
   365  	switch f {
   366  	case BytesEncodeHex:
   367  		return "hex"
   368  	case BytesEncodeEscape:
   369  		return "escape"
   370  	case BytesEncodeBase64:
   371  		return "base64"
   372  	default:
   373  		return fmt.Sprintf("invalid (%d)", f)
   374  	}
   375  }
   376  
   377  // BytesEncodeFormatFromString converts a string into a BytesEncodeFormat.
   378  func BytesEncodeFormatFromString(val string) (_ BytesEncodeFormat, ok bool) {
   379  	switch strings.ToUpper(val) {
   380  	case "HEX":
   381  		return BytesEncodeHex, true
   382  	case "ESCAPE":
   383  		return BytesEncodeEscape, true
   384  	case "BASE64":
   385  		return BytesEncodeBase64, true
   386  	default:
   387  		return -1, false
   388  	}
   389  }