github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/lex/encode.go (about) 1 // Copyright 2012, Google Inc. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in licenses/BSD-vitess.txt. 4 5 // Portions of this file are additionally subject to the following 6 // license and copyright. 7 // 8 // Copyright 2015 The Cockroach Authors. 9 // 10 // Use of this software is governed by the Business Source License 11 // included in the file licenses/BSL.txt. 12 // 13 // As of the Change Date specified in that file, in accordance with 14 // the Business Source License, use of this software will be governed 15 // by the Apache License, Version 2.0, included in the file 16 // licenses/APL.txt. 17 18 // This code was derived from https://github.com/youtube/vitess. 19 20 package lex 21 22 import ( 23 "bytes" 24 "encoding/base64" 25 "encoding/hex" 26 "fmt" 27 "strings" 28 "unicode/utf8" 29 30 "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" 31 "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" 32 "github.com/cockroachdb/cockroach/pkg/util/stringencoding" 33 "github.com/cockroachdb/errors" 34 ) 35 36 var mustQuoteMap = map[byte]bool{ 37 ' ': true, 38 ',': true, 39 '{': true, 40 '}': true, 41 } 42 43 // EncodeFlags influence the formatting of strings and identifiers. 44 type EncodeFlags int 45 46 // HasFlags tests whether the given flags are set. 47 func (f EncodeFlags) HasFlags(subset EncodeFlags) bool { 48 return f&subset == subset 49 } 50 51 const ( 52 // EncNoFlags indicates nothing special should happen while encoding. 53 EncNoFlags EncodeFlags = 0 54 55 // EncBareStrings indicates that strings will be rendered without 56 // wrapping quotes if they contain no special characters. 57 EncBareStrings EncodeFlags = 1 << iota 58 59 // EncBareIdentifiers indicates that identifiers will be rendered 60 // without wrapping quotes. 61 EncBareIdentifiers 62 63 // EncFirstFreeFlagBit needs to remain unused; it is used as base 64 // bit offset for tree.FmtFlags. 65 EncFirstFreeFlagBit 66 ) 67 68 // EncodeSQLString writes a string literal to buf. All unicode and 69 // non-printable characters are escaped. 70 func EncodeSQLString(buf *bytes.Buffer, in string) { 71 EncodeSQLStringWithFlags(buf, in, EncNoFlags) 72 } 73 74 // EscapeSQLString returns an escaped SQL representation of the given 75 // string. This is suitable for safely producing a SQL string valid 76 // for input to the parser. 77 func EscapeSQLString(in string) string { 78 var buf bytes.Buffer 79 EncodeSQLString(&buf, in) 80 return buf.String() 81 } 82 83 // EncodeSQLStringWithFlags writes a string literal to buf. All 84 // unicode and non-printable characters are escaped. flags controls 85 // the output format: if encodeBareString is set, the output string 86 // will not be wrapped in quotes if the strings contains no special 87 // characters. 88 func EncodeSQLStringWithFlags(buf *bytes.Buffer, in string, flags EncodeFlags) { 89 // See http://www.postgresql.org/docs/9.4/static/sql-syntax-lexical.html 90 start := 0 91 escapedString := false 92 bareStrings := flags.HasFlags(EncBareStrings) 93 // Loop through each unicode code point. 94 for i, r := range in { 95 if i < start { 96 continue 97 } 98 ch := byte(r) 99 if r >= 0x20 && r < 0x7F { 100 if mustQuoteMap[ch] { 101 // We have to quote this string - ignore bareStrings setting 102 bareStrings = false 103 } 104 if !stringencoding.NeedEscape(ch) && ch != '\'' { 105 continue 106 } 107 } 108 109 if !escapedString { 110 buf.WriteString("e'") // begin e'xxx' string 111 escapedString = true 112 } 113 buf.WriteString(in[start:i]) 114 ln := utf8.RuneLen(r) 115 if ln < 0 { 116 start = i + 1 117 } else { 118 start = i + ln 119 } 120 stringencoding.EncodeEscapedChar(buf, in, r, ch, i, '\'') 121 } 122 123 quote := !escapedString && !bareStrings 124 if quote { 125 buf.WriteByte('\'') // begin 'xxx' string if nothing was escaped 126 } 127 if start < len(in) { 128 buf.WriteString(in[start:]) 129 } 130 if escapedString || quote { 131 buf.WriteByte('\'') 132 } 133 } 134 135 // EncodeUnrestrictedSQLIdent writes the identifier in s to buf. 136 // The identifier is only quoted if the flags don't tell otherwise and 137 // the identifier contains special characters. 138 func EncodeUnrestrictedSQLIdent(buf *bytes.Buffer, s string, flags EncodeFlags) { 139 if flags.HasFlags(EncBareIdentifiers) || isBareIdentifier(s) { 140 buf.WriteString(s) 141 return 142 } 143 EncodeEscapedSQLIdent(buf, s) 144 } 145 146 // EncodeRestrictedSQLIdent writes the identifier in s to buf. The 147 // identifier is quoted if either the flags ask for it, the identifier 148 // contains special characters, or the identifier is a reserved SQL 149 // keyword. 150 func EncodeRestrictedSQLIdent(buf *bytes.Buffer, s string, flags EncodeFlags) { 151 if flags.HasFlags(EncBareIdentifiers) || (!isReservedKeyword(s) && isBareIdentifier(s)) { 152 buf.WriteString(s) 153 return 154 } 155 EncodeEscapedSQLIdent(buf, s) 156 } 157 158 // EncodeEscapedSQLIdent writes the identifier in s to buf. The 159 // identifier is always quoted. Double quotes inside the identifier 160 // are escaped. 161 func EncodeEscapedSQLIdent(buf *bytes.Buffer, s string) { 162 buf.WriteByte('"') 163 start := 0 164 for i, n := 0, len(s); i < n; i++ { 165 ch := s[i] 166 // The only character that requires escaping is a double quote. 167 if ch == '"' { 168 if start != i { 169 buf.WriteString(s[start:i]) 170 } 171 start = i + 1 172 buf.WriteByte(ch) 173 buf.WriteByte(ch) // add extra copy of ch 174 } 175 } 176 if start < len(s) { 177 buf.WriteString(s[start:]) 178 } 179 buf.WriteByte('"') 180 } 181 182 // EncodeLocaleName writes the locale identifier in s to buf. Any dash 183 // characters are mapped to underscore characters. Underscore characters do not 184 // need to be quoted, and they are considered equivalent to dash characters by 185 // the CLDR standard: http://cldr.unicode.org/. 186 func EncodeLocaleName(buf *bytes.Buffer, s string) { 187 for i, n := 0, len(s); i < n; i++ { 188 ch := s[i] 189 if ch == '-' { 190 buf.WriteByte('_') 191 } else { 192 buf.WriteByte(ch) 193 } 194 } 195 } 196 197 // EncodeSQLBytes encodes the SQL byte array in 'in' to buf, to a 198 // format suitable for re-scanning. We don't use a straightforward hex 199 // encoding here with x'...' because the result would be less 200 // compact. We are trading a little more time during the encoding to 201 // have a little less bytes on the wire. 202 func EncodeSQLBytes(buf *bytes.Buffer, in string) { 203 start := 0 204 buf.WriteString("b'") 205 // Loop over the bytes of the string (i.e., don't use range over unicode 206 // code points). 207 for i, n := 0, len(in); i < n; i++ { 208 ch := in[i] 209 if encodedChar := stringencoding.EncodeMap[ch]; encodedChar != stringencoding.DontEscape { 210 buf.WriteString(in[start:i]) 211 buf.WriteByte('\\') 212 buf.WriteByte(encodedChar) 213 start = i + 1 214 } else if ch == '\'' { 215 // We can't just fold this into stringencoding.EncodeMap because 216 // stringencoding.EncodeMap is also used for strings which 217 // aren't quoted with single-quotes 218 buf.WriteString(in[start:i]) 219 buf.WriteByte('\\') 220 buf.WriteByte(ch) 221 start = i + 1 222 } else if ch < 0x20 || ch >= 0x7F { 223 buf.WriteString(in[start:i]) 224 // Escape non-printable characters. 225 buf.Write(stringencoding.HexMap[ch]) 226 start = i + 1 227 } 228 } 229 buf.WriteString(in[start:]) 230 buf.WriteByte('\'') 231 } 232 233 // EncodeByteArrayToRawBytes converts a SQL-level byte array into raw 234 // bytes according to the encoding specification in "be". 235 // If the skipHexPrefix argument is set, the hexadecimal encoding does not 236 // prefix the output with "\x". This is suitable e.g. for the encode() 237 // built-in. 238 func EncodeByteArrayToRawBytes(data string, be BytesEncodeFormat, skipHexPrefix bool) string { 239 switch be { 240 case BytesEncodeHex: 241 head := 2 242 if skipHexPrefix { 243 head = 0 244 } 245 res := make([]byte, head+hex.EncodedLen(len(data))) 246 if !skipHexPrefix { 247 res[0] = '\\' 248 res[1] = 'x' 249 } 250 hex.Encode(res[head:], []byte(data)) 251 return string(res) 252 253 case BytesEncodeEscape: 254 // PostgreSQL does not allow all the escapes formats recognized by 255 // CockroachDB's scanner. It only recognizes octal and \\ for the 256 // backslash itself. 257 // See https://www.postgresql.org/docs/current/static/datatype-binary.html#AEN5667 258 res := make([]byte, 0, len(data)) 259 for _, c := range []byte(data) { 260 if c == '\\' { 261 res = append(res, '\\', '\\') 262 } else if c < 32 || c >= 127 { 263 // Escape the character in octal. 264 // 265 // Note: CockroachDB only supports UTF-8 for which all values 266 // below 128 are ASCII. There is no locale-dependent escaping 267 // in that case. 268 res = append(res, '\\', '0'+(c>>6), '0'+((c>>3)&7), '0'+(c&7)) 269 } else { 270 res = append(res, c) 271 } 272 } 273 return string(res) 274 275 case BytesEncodeBase64: 276 return base64.StdEncoding.EncodeToString([]byte(data)) 277 278 default: 279 panic(fmt.Sprintf("unhandled format: %s", be)) 280 } 281 } 282 283 // DecodeRawBytesToByteArray converts raw bytes to a SQL-level byte array 284 // according to the encoding specification in "be". 285 // When using the Hex format, the caller is responsible for skipping the 286 // "\x" prefix, if any. See DecodeRawBytesToByteArrayAuto() below for 287 // an alternative. 288 func DecodeRawBytesToByteArray(data string, be BytesEncodeFormat) ([]byte, error) { 289 switch be { 290 case BytesEncodeHex: 291 return hex.DecodeString(data) 292 293 case BytesEncodeEscape: 294 // PostgreSQL does not allow all the escapes formats recognized by 295 // CockroachDB's scanner. It only recognizes octal and \\ for the 296 // backslash itself. 297 // See https://www.postgresql.org/docs/current/static/datatype-binary.html#AEN5667 298 res := make([]byte, 0, len(data)) 299 for i := 0; i < len(data); i++ { 300 ch := data[i] 301 if ch != '\\' { 302 res = append(res, ch) 303 continue 304 } 305 if i >= len(data)-1 { 306 return nil, pgerror.New(pgcode.InvalidEscapeSequence, 307 "bytea encoded value ends with escape character") 308 } 309 if data[i+1] == '\\' { 310 res = append(res, '\\') 311 i++ 312 continue 313 } 314 if i+3 >= len(data) { 315 return nil, pgerror.New(pgcode.InvalidEscapeSequence, 316 "bytea encoded value ends with incomplete escape sequence") 317 } 318 b := byte(0) 319 for j := 1; j <= 3; j++ { 320 octDigit := data[i+j] 321 if octDigit < '0' || octDigit > '7' || (j == 1 && octDigit > '3') { 322 return nil, pgerror.New(pgcode.InvalidEscapeSequence, 323 "invalid bytea escape sequence") 324 } 325 b = (b << 3) | (octDigit - '0') 326 } 327 res = append(res, b) 328 i += 3 329 } 330 return res, nil 331 332 case BytesEncodeBase64: 333 return base64.StdEncoding.DecodeString(data) 334 335 default: 336 return nil, errors.AssertionFailedf("unhandled format: %s", be) 337 } 338 } 339 340 // DecodeRawBytesToByteArrayAuto detects which format to use with 341 // DecodeRawBytesToByteArray(). It only supports hex ("\x" prefix) 342 // and escape. 343 func DecodeRawBytesToByteArrayAuto(data []byte) ([]byte, error) { 344 if len(data) >= 2 && data[0] == '\\' && (data[1] == 'x' || data[1] == 'X') { 345 return DecodeRawBytesToByteArray(string(data[2:]), BytesEncodeHex) 346 } 347 return DecodeRawBytesToByteArray(string(data), BytesEncodeEscape) 348 } 349 350 // BytesEncodeFormat controls which format to use for BYTES->STRING 351 // conversions. 352 type BytesEncodeFormat int 353 354 const ( 355 // BytesEncodeHex uses the hex format: e'abc\n'::BYTES::STRING -> '\x61626312'. 356 // This is the default, for compatibility with PostgreSQL. 357 BytesEncodeHex BytesEncodeFormat = iota 358 // BytesEncodeEscape uses the escaped format: e'abc\n'::BYTES::STRING -> 'abc\012'. 359 BytesEncodeEscape 360 // BytesEncodeBase64 uses base64 encoding. 361 BytesEncodeBase64 362 ) 363 364 func (f BytesEncodeFormat) String() string { 365 switch f { 366 case BytesEncodeHex: 367 return "hex" 368 case BytesEncodeEscape: 369 return "escape" 370 case BytesEncodeBase64: 371 return "base64" 372 default: 373 return fmt.Sprintf("invalid (%d)", f) 374 } 375 } 376 377 // BytesEncodeFormatFromString converts a string into a BytesEncodeFormat. 378 func BytesEncodeFormatFromString(val string) (_ BytesEncodeFormat, ok bool) { 379 switch strings.ToUpper(val) { 380 case "HEX": 381 return BytesEncodeHex, true 382 case "ESCAPE": 383 return BytesEncodeEscape, true 384 case "BASE64": 385 return BytesEncodeBase64, true 386 default: 387 return -1, false 388 } 389 }