github.com/cockroachdb/cockroachdb-parser@v0.23.3-0.20240213214944-911057d40c9a/pkg/sql/lex/encode.go (about)

     1  // Copyright 2012, Google Inc. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in licenses/BSD-vitess.txt.
     4  
     5  // Portions of this file are additionally subject to the following
     6  // license and copyright.
     7  //
     8  // Copyright 2015 The Cockroach Authors.
     9  //
    10  // Use of this software is governed by the Business Source License
    11  // included in the file licenses/BSL.txt.
    12  //
    13  // As of the Change Date specified in that file, in accordance with
    14  // the Business Source License, use of this software will be governed
    15  // by the Apache License, Version 2.0, included in the file
    16  // licenses/APL.txt.
    17  
    18  // This code was derived from https://github.com/youtube/vitess.
    19  
    20  package lex
    21  
    22  import (
    23  	"bytes"
    24  	"encoding/base64"
    25  	"encoding/hex"
    26  	"fmt"
    27  	"strings"
    28  	"unicode"
    29  
    30  	"github.com/cockroachdb/cockroachdb-parser/pkg/sql/pgwire/pgcode"
    31  	"github.com/cockroachdb/cockroachdb-parser/pkg/sql/pgwire/pgerror"
    32  	"github.com/cockroachdb/errors"
    33  	"golang.org/x/text/language"
    34  )
    35  
    36  // NormalizeLocaleName returns a normalized locale identifier based on s. The
    37  // case of the locale is normalized and any dash characters are mapped to
    38  // underscore characters.
    39  func NormalizeLocaleName(s string) string {
    40  	b := bytes.NewBuffer(make([]byte, 0, len(s)))
    41  	EncodeLocaleName(b, s)
    42  	return b.String()
    43  }
    44  
    45  // EncodeLocaleName writes the locale identifier in s to buf. Any dash
    46  // characters are mapped to underscore characters. Underscore characters do not
    47  // need to be quoted, and they are considered equivalent to dash characters by
    48  // the CLDR standard: http://cldr.unicode.org/.
    49  func EncodeLocaleName(buf *bytes.Buffer, s string) {
    50  	// If possible, try to normalize the case of the locale name.
    51  	if normalized, err := language.Parse(s); err == nil {
    52  		s = normalized.String()
    53  	}
    54  	for i, n := 0, len(s); i < n; i++ {
    55  		ch := s[i]
    56  		if ch == '-' {
    57  			buf.WriteByte('_')
    58  		} else {
    59  			buf.WriteByte(ch)
    60  		}
    61  	}
    62  }
    63  
    64  // LocaleNamesAreEqual checks for equality of two locale names. The comparison
    65  // is case-insensitive and treats '-' and '_' as the same.
    66  func LocaleNamesAreEqual(a, b string) bool {
    67  	if a == b {
    68  		return true
    69  	}
    70  	if len(a) != len(b) {
    71  		return false
    72  	}
    73  	for i, n := 0, len(a); i < n; i++ {
    74  		ai, bi := a[i], b[i]
    75  		if ai == bi {
    76  			continue
    77  		}
    78  		if ai == '-' && bi == '_' {
    79  			continue
    80  		}
    81  		if ai == '_' && bi == '-' {
    82  			continue
    83  		}
    84  		if unicode.ToLower(rune(ai)) != unicode.ToLower(rune(bi)) {
    85  			return false
    86  		}
    87  	}
    88  	return true
    89  }
    90  
    91  // EncodeByteArrayToRawBytes converts a SQL-level byte array into raw
    92  // bytes according to the encoding specification in "be".
    93  // If the skipHexPrefix argument is set, the hexadecimal encoding does not
    94  // prefix the output with "\x". This is suitable e.g. for the encode()
    95  // built-in.
    96  func EncodeByteArrayToRawBytes(data string, be BytesEncodeFormat, skipHexPrefix bool) string {
    97  	switch be {
    98  	case BytesEncodeHex:
    99  		head := 2
   100  		if skipHexPrefix {
   101  			head = 0
   102  		}
   103  		res := make([]byte, head+hex.EncodedLen(len(data)))
   104  		if !skipHexPrefix {
   105  			res[0] = '\\'
   106  			res[1] = 'x'
   107  		}
   108  		hex.Encode(res[head:], []byte(data))
   109  		return string(res)
   110  
   111  	case BytesEncodeEscape:
   112  		// PostgreSQL does not allow all the escapes formats recognized by
   113  		// CockroachDB's scanner. It only recognizes octal and \\ for the
   114  		// backslash itself.
   115  		// See https://www.postgresql.org/docs/current/static/datatype-binary.html#AEN5667
   116  		res := make([]byte, 0, len(data))
   117  		for _, c := range []byte(data) {
   118  			if c == '\\' {
   119  				res = append(res, '\\', '\\')
   120  			} else if c < 32 || c >= 127 {
   121  				// Escape the character in octal.
   122  				//
   123  				// Note: CockroachDB only supports UTF-8 for which all values
   124  				// below 128 are ASCII. There is no locale-dependent escaping
   125  				// in that case.
   126  				res = append(res, '\\', '0'+(c>>6), '0'+((c>>3)&7), '0'+(c&7))
   127  			} else {
   128  				res = append(res, c)
   129  			}
   130  		}
   131  		return string(res)
   132  
   133  	case BytesEncodeBase64:
   134  		return base64.StdEncoding.EncodeToString([]byte(data))
   135  
   136  	default:
   137  		panic(errors.AssertionFailedf("unhandled format: %s", be))
   138  	}
   139  }
   140  
   141  // DecodeRawBytesToByteArray converts raw bytes to a SQL-level byte array
   142  // according to the encoding specification in "be".
   143  // When using the Hex format, the caller is responsible for skipping the
   144  // "\x" prefix, if any. See DecodeRawBytesToByteArrayAuto() below for
   145  // an alternative. If no conversion is necessary the input is returned,
   146  // callers should not assume a copy is made.
   147  func DecodeRawBytesToByteArray(data []byte, be BytesEncodeFormat) ([]byte, error) {
   148  	switch be {
   149  	case BytesEncodeHex:
   150  		res := make([]byte, hex.DecodedLen(len(data)))
   151  		n, err := hex.Decode(res, data)
   152  		return res[:n], err
   153  
   154  	case BytesEncodeEscape:
   155  		// PostgreSQL does not allow all the escapes formats recognized by
   156  		// CockroachDB's scanner. It only recognizes octal and \\ for the
   157  		// backslash itself.
   158  		// See https://www.postgresql.org/docs/current/static/datatype-binary.html#AEN5667
   159  		res := data
   160  		copied := false
   161  		for i := 0; i < len(data); i++ {
   162  			ch := data[i]
   163  			if ch != '\\' {
   164  				if copied {
   165  					res = append(res, ch)
   166  				}
   167  				continue
   168  			}
   169  			if i >= len(data)-1 {
   170  				return nil, pgerror.New(pgcode.InvalidEscapeSequence,
   171  					"bytea encoded value ends with escape character")
   172  			}
   173  			if !copied {
   174  				res = make([]byte, 0, len(data))
   175  				res = append(res, data[:i]...)
   176  				copied = true
   177  			}
   178  			if data[i+1] == '\\' {
   179  				res = append(res, '\\')
   180  				i++
   181  				continue
   182  			}
   183  			if i+3 >= len(data) {
   184  				return nil, pgerror.New(pgcode.InvalidEscapeSequence,
   185  					"bytea encoded value ends with incomplete escape sequence")
   186  			}
   187  			b := byte(0)
   188  			for j := 1; j <= 3; j++ {
   189  				octDigit := data[i+j]
   190  				if octDigit < '0' || octDigit > '7' || (j == 1 && octDigit > '3') {
   191  					return nil, pgerror.New(pgcode.InvalidEscapeSequence,
   192  						"invalid bytea escape sequence")
   193  				}
   194  				b = (b << 3) | (octDigit - '0')
   195  			}
   196  			res = append(res, b)
   197  			i += 3
   198  		}
   199  		return res, nil
   200  
   201  	case BytesEncodeBase64:
   202  		res := make([]byte, base64.StdEncoding.DecodedLen(len(data)))
   203  		n, err := base64.StdEncoding.Decode(res, data)
   204  		return res[:n], err
   205  
   206  	default:
   207  		return nil, errors.AssertionFailedf("unhandled format: %s", be)
   208  	}
   209  }
   210  
   211  // DecodeRawBytesToByteArrayAuto detects which format to use with
   212  // DecodeRawBytesToByteArray(). It only supports hex ("\x" prefix)
   213  // and escape.
   214  func DecodeRawBytesToByteArrayAuto(data []byte) ([]byte, error) {
   215  	if len(data) >= 2 && data[0] == '\\' && (data[1] == 'x' || data[1] == 'X') {
   216  		return DecodeRawBytesToByteArray(data[2:], BytesEncodeHex)
   217  	}
   218  	return DecodeRawBytesToByteArray(data, BytesEncodeEscape)
   219  }
   220  
   221  func (f BytesEncodeFormat) String() string {
   222  	switch f {
   223  	case BytesEncodeHex:
   224  		return "hex"
   225  	case BytesEncodeEscape:
   226  		return "escape"
   227  	case BytesEncodeBase64:
   228  		return "base64"
   229  	default:
   230  		return fmt.Sprintf("invalid (%d)", f)
   231  	}
   232  }
   233  
   234  // BytesEncodeFormatFromString converts a string into a BytesEncodeFormat.
   235  func BytesEncodeFormatFromString(val string) (_ BytesEncodeFormat, ok bool) {
   236  	switch strings.ToUpper(val) {
   237  	case "HEX":
   238  		return BytesEncodeHex, true
   239  	case "ESCAPE":
   240  		return BytesEncodeEscape, true
   241  	case "BASE64":
   242  		return BytesEncodeBase64, true
   243  	default:
   244  		return -1, false
   245  	}
   246  }