github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/lex/encode_test.go (about) 1 // Copyright 2017 The Cockroach Authors. 2 // 3 // Use of this software is governed by the Business Source License 4 // included in the file licenses/BSL.txt. 5 // 6 // As of the Change Date specified in that file, in accordance with 7 // the Business Source License, use of this software will be governed 8 // by the Apache License, Version 2.0, included in the file 9 // licenses/APL.txt. 10 11 package lex_test 12 13 import ( 14 "bytes" 15 "fmt" 16 "strings" 17 "testing" 18 "unicode/utf8" 19 20 "github.com/cockroachdb/cockroach/pkg/sql/lex" 21 "github.com/cockroachdb/cockroach/pkg/sql/parser" 22 ) 23 24 func TestEncodeSQLBytes(t *testing.T) { 25 testEncodeSQL(t, lex.EncodeSQLBytes, false) 26 } 27 28 func TestEncodeSQLString(t *testing.T) { 29 testEncodeSQL(t, lex.EncodeSQLString, true) 30 } 31 32 func testEncodeSQL(t *testing.T, encode func(*bytes.Buffer, string), forceUTF8 bool) { 33 type entry struct{ i, j int } 34 seen := make(map[string]entry) 35 for i := 0; i < 256; i++ { 36 for j := 0; j < 256; j++ { 37 bytepair := []byte{byte(i), byte(j)} 38 if forceUTF8 && !utf8.Valid(bytepair) { 39 continue 40 } 41 stmt := testEncodeString(t, bytepair, encode) 42 if e, ok := seen[stmt]; ok { 43 t.Fatalf("duplicate entry: %s, from %v, currently at %v, %v", stmt, e, i, j) 44 } 45 seen[stmt] = entry{i, j} 46 } 47 } 48 } 49 50 func TestEncodeSQLStringSpecial(t *testing.T) { 51 tests := [][]byte{ 52 // UTF8 replacement character 53 {0xEF, 0xBF, 0xBD}, 54 } 55 for _, tc := range tests { 56 testEncodeString(t, tc, lex.EncodeSQLString) 57 } 58 } 59 60 func testEncodeString(t *testing.T, input []byte, encode func(*bytes.Buffer, string)) string { 61 s := string(input) 62 var buf bytes.Buffer 63 encode(&buf, s) 64 sql := fmt.Sprintf("SELECT %s", buf.String()) 65 for n := 0; n < len(sql); n++ { 66 ch := sql[n] 67 if ch < 0x20 || ch >= 0x7F { 68 t.Fatalf("unprintable character: %v (%v): %s %v", ch, input, sql, []byte(sql)) 69 } 70 } 71 stmts, err := parser.Parse(sql) 72 if err != nil { 73 t.Fatalf("%s: expected success, but found %s", sql, err) 74 } 75 stmt := stmts.String() 76 if sql != stmt { 77 t.Fatalf("expected %s, but found %s", sql, stmt) 78 } 79 return stmt 80 } 81 82 func BenchmarkEncodeSQLString(b *testing.B) { 83 str := strings.Repeat("foo", 10000) 84 for i := 0; i < b.N; i++ { 85 lex.EncodeSQLStringWithFlags(bytes.NewBuffer(nil), str, lex.EncBareStrings) 86 } 87 } 88 89 func TestEncodeRestrictedSQLIdent(t *testing.T) { 90 testCases := []struct { 91 input string 92 output string 93 }{ 94 {`foo`, `foo`}, 95 {``, `""`}, 96 {`3`, `"3"`}, 97 {`foo3`, `foo3`}, 98 {`foo"`, `"foo"""`}, 99 {`fo"o"`, `"fo""o"""`}, 100 {`fOo`, `"fOo"`}, 101 {`_foo`, `_foo`}, 102 {`-foo`, `"-foo"`}, 103 {`select`, `"select"`}, 104 {`integer`, `"integer"`}, 105 // N.B. These type names are examples of type names that *should* be 106 // unrestricted (left out of the reserved keyword list) because they're not 107 // part of the sql standard type name list. This is important for Postgres 108 // compatibility. If you find yourself about to change this, don't - you can 109 // convince yourself of such by looking at the output of `quote_ident` 110 // against a Postgres instance. 111 {`int8`, `int8`}, 112 {`date`, `date`}, 113 {`inet`, `inet`}, 114 } 115 116 for _, tc := range testCases { 117 var buf bytes.Buffer 118 lex.EncodeRestrictedSQLIdent(&buf, tc.input, lex.EncBareStrings) 119 out := buf.String() 120 121 if out != tc.output { 122 t.Errorf("`%s`: expected `%s`, got `%s`", tc.input, tc.output, out) 123 } 124 } 125 } 126 127 func TestByteArrayDecoding(t *testing.T) { 128 const ( 129 fmtHex = lex.BytesEncodeHex 130 fmtEsc = lex.BytesEncodeEscape 131 fmtB64 = lex.BytesEncodeBase64 132 ) 133 testData := []struct { 134 in string 135 auto bool 136 inFmt lex.BytesEncodeFormat 137 out string 138 err string 139 }{ 140 {`a`, false, fmtHex, "", "encoding/hex: odd length hex string"}, 141 {`aa`, false, fmtHex, "\xaa", ""}, 142 {`aA`, false, fmtHex, "\xaa", ""}, 143 {`AA`, false, fmtHex, "\xaa", ""}, 144 {`x0`, false, fmtHex, "", "encoding/hex: invalid byte: U+0078 'x'"}, 145 {`a\nbcd`, false, fmtEsc, "", "invalid bytea escape sequence"}, 146 {`a\'bcd`, false, fmtEsc, "", "invalid bytea escape sequence"}, 147 {`a\00`, false, fmtEsc, "", "bytea encoded value ends with incomplete escape sequence"}, 148 {`a\099`, false, fmtEsc, "", "invalid bytea escape sequence"}, 149 {`a\400`, false, fmtEsc, "", "invalid bytea escape sequence"}, 150 {`a\777`, false, fmtEsc, "", "invalid bytea escape sequence"}, 151 {`a'b`, false, fmtEsc, "a'b", ""}, 152 {`a''b`, false, fmtEsc, "a''b", ""}, 153 {`a\\b`, false, fmtEsc, "a\\b", ""}, 154 {`a\000b`, false, fmtEsc, "a\x00b", ""}, 155 {"a\nb", false, fmtEsc, "a\nb", ""}, 156 {`a`, false, fmtB64, "", "illegal base64 data at input byte 0"}, 157 {`aa=`, false, fmtB64, "", "illegal base64 data at input byte 3"}, 158 {`AA==`, false, fmtB64, "\x00", ""}, 159 {`/w==`, false, fmtB64, "\xff", ""}, 160 {`AAAA`, false, fmtB64, "\x00\x00\x00", ""}, 161 {`a`, true, 0, "a", ""}, 162 {`\x`, true, 0, "", ""}, 163 {`\xx`, true, 0, "", "encoding/hex: invalid byte: U+0078 'x'"}, 164 {`\x6162`, true, 0, "ab", ""}, 165 {`\\x6162`, true, 0, "\\x6162", ""}, 166 } 167 for _, s := range testData { 168 t.Run(fmt.Sprintf("%s:%s", s.in, s.inFmt), func(t *testing.T) { 169 var dec []byte 170 var err error 171 if s.auto { 172 dec, err = lex.DecodeRawBytesToByteArrayAuto([]byte(s.in)) 173 } else { 174 dec, err = lex.DecodeRawBytesToByteArray(s.in, s.inFmt) 175 } 176 if s.err != "" { 177 if err == nil { 178 t.Fatalf("expected err %q, got no error", s.err) 179 } 180 if s.err != err.Error() { 181 t.Fatalf("expected err %q, got %q", s.err, err) 182 } 183 return 184 } 185 if err != nil { 186 t.Fatal(err) 187 } 188 if string(dec) != s.out { 189 t.Fatalf("expected %q, got %q", s.out, dec) 190 } 191 }) 192 } 193 } 194 195 func TestByteArrayEncoding(t *testing.T) { 196 testData := []struct { 197 in string 198 out []string 199 }{ 200 // The reference values were gathered from PostgreSQL. 201 {"", []string{`\x`, ``, ``}}, 202 {"abc", []string{`\x616263`, `abc`, `YWJj`}}, 203 {"a\nb", []string{`\x610a62`, `a\012b`, `YQpi`}}, 204 {`a\nb`, []string{`\x615c6e62`, `a\\nb`, `YVxuYg==`}}, 205 {"a'b", []string{`\x612762`, `a'b`, `YSdi`}}, 206 {"a\"b", []string{`\x612262`, `a"b`, `YSJi`}}, 207 {"a\x00b", []string{`\x610062`, `a\000b`, `YQBi`}}, 208 } 209 210 for _, s := range testData { 211 t.Run(s.in, func(t *testing.T) { 212 for _, format := range []lex.BytesEncodeFormat{ 213 lex.BytesEncodeHex, lex.BytesEncodeEscape, lex.BytesEncodeBase64} { 214 t.Run(format.String(), func(t *testing.T) { 215 enc := lex.EncodeByteArrayToRawBytes(s.in, format, false) 216 217 expEnc := s.out[int(format)] 218 if enc != expEnc { 219 t.Fatalf("encoded %q, expected %q", enc, expEnc) 220 } 221 222 if format == lex.BytesEncodeHex { 223 // Check that the \x also can be skipped. 224 enc2 := lex.EncodeByteArrayToRawBytes(s.in, format, true) 225 if enc[2:] != enc2 { 226 t.Fatal("can't skip prefix") 227 } 228 enc = enc[2:] 229 } 230 231 dec, err := lex.DecodeRawBytesToByteArray(enc, format) 232 if err != nil { 233 t.Fatal(err) 234 } 235 if string(dec) != s.in { 236 t.Fatalf("decoded %q, expected %q", string(dec), s.in) 237 } 238 }) 239 } 240 }) 241 } 242 }