github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/lex/encode_test.go (about)

     1  // Copyright 2017 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package lex_test
    12  
    13  import (
    14  	"bytes"
    15  	"fmt"
    16  	"strings"
    17  	"testing"
    18  	"unicode/utf8"
    19  
    20  	"github.com/cockroachdb/cockroach/pkg/sql/lex"
    21  	"github.com/cockroachdb/cockroach/pkg/sql/parser"
    22  )
    23  
    24  func TestEncodeSQLBytes(t *testing.T) {
    25  	testEncodeSQL(t, lex.EncodeSQLBytes, false)
    26  }
    27  
    28  func TestEncodeSQLString(t *testing.T) {
    29  	testEncodeSQL(t, lex.EncodeSQLString, true)
    30  }
    31  
    32  func testEncodeSQL(t *testing.T, encode func(*bytes.Buffer, string), forceUTF8 bool) {
    33  	type entry struct{ i, j int }
    34  	seen := make(map[string]entry)
    35  	for i := 0; i < 256; i++ {
    36  		for j := 0; j < 256; j++ {
    37  			bytepair := []byte{byte(i), byte(j)}
    38  			if forceUTF8 && !utf8.Valid(bytepair) {
    39  				continue
    40  			}
    41  			stmt := testEncodeString(t, bytepair, encode)
    42  			if e, ok := seen[stmt]; ok {
    43  				t.Fatalf("duplicate entry: %s, from %v, currently at %v, %v", stmt, e, i, j)
    44  			}
    45  			seen[stmt] = entry{i, j}
    46  		}
    47  	}
    48  }
    49  
    50  func TestEncodeSQLStringSpecial(t *testing.T) {
    51  	tests := [][]byte{
    52  		// UTF8 replacement character
    53  		{0xEF, 0xBF, 0xBD},
    54  	}
    55  	for _, tc := range tests {
    56  		testEncodeString(t, tc, lex.EncodeSQLString)
    57  	}
    58  }
    59  
    60  func testEncodeString(t *testing.T, input []byte, encode func(*bytes.Buffer, string)) string {
    61  	s := string(input)
    62  	var buf bytes.Buffer
    63  	encode(&buf, s)
    64  	sql := fmt.Sprintf("SELECT %s", buf.String())
    65  	for n := 0; n < len(sql); n++ {
    66  		ch := sql[n]
    67  		if ch < 0x20 || ch >= 0x7F {
    68  			t.Fatalf("unprintable character: %v (%v): %s %v", ch, input, sql, []byte(sql))
    69  		}
    70  	}
    71  	stmts, err := parser.Parse(sql)
    72  	if err != nil {
    73  		t.Fatalf("%s: expected success, but found %s", sql, err)
    74  	}
    75  	stmt := stmts.String()
    76  	if sql != stmt {
    77  		t.Fatalf("expected %s, but found %s", sql, stmt)
    78  	}
    79  	return stmt
    80  }
    81  
    82  func BenchmarkEncodeSQLString(b *testing.B) {
    83  	str := strings.Repeat("foo", 10000)
    84  	for i := 0; i < b.N; i++ {
    85  		lex.EncodeSQLStringWithFlags(bytes.NewBuffer(nil), str, lex.EncBareStrings)
    86  	}
    87  }
    88  
    89  func TestEncodeRestrictedSQLIdent(t *testing.T) {
    90  	testCases := []struct {
    91  		input  string
    92  		output string
    93  	}{
    94  		{`foo`, `foo`},
    95  		{``, `""`},
    96  		{`3`, `"3"`},
    97  		{`foo3`, `foo3`},
    98  		{`foo"`, `"foo"""`},
    99  		{`fo"o"`, `"fo""o"""`},
   100  		{`fOo`, `"fOo"`},
   101  		{`_foo`, `_foo`},
   102  		{`-foo`, `"-foo"`},
   103  		{`select`, `"select"`},
   104  		{`integer`, `"integer"`},
   105  		// N.B. These type names are examples of type names that *should* be
   106  		// unrestricted (left out of the reserved keyword list) because they're not
   107  		// part of the sql standard type name list. This is important for Postgres
   108  		// compatibility. If you find yourself about to change this, don't - you can
   109  		// convince yourself of such by looking at the output of `quote_ident`
   110  		// against a Postgres instance.
   111  		{`int8`, `int8`},
   112  		{`date`, `date`},
   113  		{`inet`, `inet`},
   114  	}
   115  
   116  	for _, tc := range testCases {
   117  		var buf bytes.Buffer
   118  		lex.EncodeRestrictedSQLIdent(&buf, tc.input, lex.EncBareStrings)
   119  		out := buf.String()
   120  
   121  		if out != tc.output {
   122  			t.Errorf("`%s`: expected `%s`, got `%s`", tc.input, tc.output, out)
   123  		}
   124  	}
   125  }
   126  
   127  func TestByteArrayDecoding(t *testing.T) {
   128  	const (
   129  		fmtHex = lex.BytesEncodeHex
   130  		fmtEsc = lex.BytesEncodeEscape
   131  		fmtB64 = lex.BytesEncodeBase64
   132  	)
   133  	testData := []struct {
   134  		in    string
   135  		auto  bool
   136  		inFmt lex.BytesEncodeFormat
   137  		out   string
   138  		err   string
   139  	}{
   140  		{`a`, false, fmtHex, "", "encoding/hex: odd length hex string"},
   141  		{`aa`, false, fmtHex, "\xaa", ""},
   142  		{`aA`, false, fmtHex, "\xaa", ""},
   143  		{`AA`, false, fmtHex, "\xaa", ""},
   144  		{`x0`, false, fmtHex, "", "encoding/hex: invalid byte: U+0078 'x'"},
   145  		{`a\nbcd`, false, fmtEsc, "", "invalid bytea escape sequence"},
   146  		{`a\'bcd`, false, fmtEsc, "", "invalid bytea escape sequence"},
   147  		{`a\00`, false, fmtEsc, "", "bytea encoded value ends with incomplete escape sequence"},
   148  		{`a\099`, false, fmtEsc, "", "invalid bytea escape sequence"},
   149  		{`a\400`, false, fmtEsc, "", "invalid bytea escape sequence"},
   150  		{`a\777`, false, fmtEsc, "", "invalid bytea escape sequence"},
   151  		{`a'b`, false, fmtEsc, "a'b", ""},
   152  		{`a''b`, false, fmtEsc, "a''b", ""},
   153  		{`a\\b`, false, fmtEsc, "a\\b", ""},
   154  		{`a\000b`, false, fmtEsc, "a\x00b", ""},
   155  		{"a\nb", false, fmtEsc, "a\nb", ""},
   156  		{`a`, false, fmtB64, "", "illegal base64 data at input byte 0"},
   157  		{`aa=`, false, fmtB64, "", "illegal base64 data at input byte 3"},
   158  		{`AA==`, false, fmtB64, "\x00", ""},
   159  		{`/w==`, false, fmtB64, "\xff", ""},
   160  		{`AAAA`, false, fmtB64, "\x00\x00\x00", ""},
   161  		{`a`, true, 0, "a", ""},
   162  		{`\x`, true, 0, "", ""},
   163  		{`\xx`, true, 0, "", "encoding/hex: invalid byte: U+0078 'x'"},
   164  		{`\x6162`, true, 0, "ab", ""},
   165  		{`\\x6162`, true, 0, "\\x6162", ""},
   166  	}
   167  	for _, s := range testData {
   168  		t.Run(fmt.Sprintf("%s:%s", s.in, s.inFmt), func(t *testing.T) {
   169  			var dec []byte
   170  			var err error
   171  			if s.auto {
   172  				dec, err = lex.DecodeRawBytesToByteArrayAuto([]byte(s.in))
   173  			} else {
   174  				dec, err = lex.DecodeRawBytesToByteArray(s.in, s.inFmt)
   175  			}
   176  			if s.err != "" {
   177  				if err == nil {
   178  					t.Fatalf("expected err %q, got no error", s.err)
   179  				}
   180  				if s.err != err.Error() {
   181  					t.Fatalf("expected err %q, got %q", s.err, err)
   182  				}
   183  				return
   184  			}
   185  			if err != nil {
   186  				t.Fatal(err)
   187  			}
   188  			if string(dec) != s.out {
   189  				t.Fatalf("expected %q, got %q", s.out, dec)
   190  			}
   191  		})
   192  	}
   193  }
   194  
   195  func TestByteArrayEncoding(t *testing.T) {
   196  	testData := []struct {
   197  		in  string
   198  		out []string
   199  	}{
   200  		// The reference values were gathered from PostgreSQL.
   201  		{"", []string{`\x`, ``, ``}},
   202  		{"abc", []string{`\x616263`, `abc`, `YWJj`}},
   203  		{"a\nb", []string{`\x610a62`, `a\012b`, `YQpi`}},
   204  		{`a\nb`, []string{`\x615c6e62`, `a\\nb`, `YVxuYg==`}},
   205  		{"a'b", []string{`\x612762`, `a'b`, `YSdi`}},
   206  		{"a\"b", []string{`\x612262`, `a"b`, `YSJi`}},
   207  		{"a\x00b", []string{`\x610062`, `a\000b`, `YQBi`}},
   208  	}
   209  
   210  	for _, s := range testData {
   211  		t.Run(s.in, func(t *testing.T) {
   212  			for _, format := range []lex.BytesEncodeFormat{
   213  				lex.BytesEncodeHex, lex.BytesEncodeEscape, lex.BytesEncodeBase64} {
   214  				t.Run(format.String(), func(t *testing.T) {
   215  					enc := lex.EncodeByteArrayToRawBytes(s.in, format, false)
   216  
   217  					expEnc := s.out[int(format)]
   218  					if enc != expEnc {
   219  						t.Fatalf("encoded %q, expected %q", enc, expEnc)
   220  					}
   221  
   222  					if format == lex.BytesEncodeHex {
   223  						// Check that the \x also can be skipped.
   224  						enc2 := lex.EncodeByteArrayToRawBytes(s.in, format, true)
   225  						if enc[2:] != enc2 {
   226  							t.Fatal("can't skip prefix")
   227  						}
   228  						enc = enc[2:]
   229  					}
   230  
   231  					dec, err := lex.DecodeRawBytesToByteArray(enc, format)
   232  					if err != nil {
   233  						t.Fatal(err)
   234  					}
   235  					if string(dec) != s.in {
   236  						t.Fatalf("decoded %q, expected %q", string(dec), s.in)
   237  					}
   238  				})
   239  			}
   240  		})
   241  	}
   242  }