vitess.io/vitess@v0.16.2/go/vt/sqlparser/normalizer_test.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package sqlparser
    18  
    19  import (
    20  	"bytes"
    21  	"fmt"
    22  	"math/rand"
    23  	"reflect"
    24  	"regexp"
    25  	"strconv"
    26  	"testing"
    27  
    28  	"github.com/stretchr/testify/assert"
    29  
    30  	"github.com/stretchr/testify/require"
    31  
    32  	"vitess.io/vitess/go/sqltypes"
    33  	querypb "vitess.io/vitess/go/vt/proto/query"
    34  	vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
    35  	"vitess.io/vitess/go/vt/vterrors"
    36  )
    37  
// TestNormalize runs Normalize over a table of SQL statements and verifies,
// for each one, both the text of the rewritten statement and the set of bind
// variables that Normalize produced.
//
// Bind variable names are reserved using the "bv" prefix, and the bind
// variables already referenced by the parsed input are handed to
// NewReservedVars as known names.
func TestNormalize(t *testing.T) {
	prefix := "bv"
	testcases := []struct {
		in      string                           // SQL text given to Parse
		outstmt string                           // expected String(stmt) after Normalize
		outbv   map[string]*querypb.BindVariable // expected bind variables filled in by Normalize
	}{{
		// str val
		in:      "select * from t where foobar = 'aa'",
		outstmt: "select * from t where foobar = :foobar",
		outbv: map[string]*querypb.BindVariable{
			"foobar": sqltypes.StringBindVariable("aa"),
		},
	}, {
		// placeholder
		in:      "select * from t where col=?",
		outstmt: "select * from t where col = :v1",
		outbv:   map[string]*querypb.BindVariable{},
	}, {
		// backtick-quoted table name
		in:      "select * from `t` where col=?",
		outstmt: "select * from t where col = :v1",
		outbv:   map[string]*querypb.BindVariable{},
	}, {
		// str val in select
		in:      "select 'aa' from t",
		outstmt: "select :bv1 from t",
		outbv: map[string]*querypb.BindVariable{
			"bv1": sqltypes.StringBindVariable("aa"),
		},
	}, {
		// int val
		in:      "select * from t where foobar = 1",
		outstmt: "select * from t where foobar = :foobar",
		outbv: map[string]*querypb.BindVariable{
			"foobar": sqltypes.Int64BindVariable(1),
		},
	}, {
		// float val
		in:      "select * from t where foobar = 1.2",
		outstmt: "select * from t where foobar = :foobar",
		outbv: map[string]*querypb.BindVariable{
			"foobar": sqltypes.DecimalBindVariable(1.2),
		},
	}, {
		// multiple vals
		in:      "select * from t where foo = 1.2 and bar = 2",
		outstmt: "select * from t where foo = :foo and bar = :bar",
		outbv: map[string]*querypb.BindVariable{
			"foo": sqltypes.DecimalBindVariable(1.2),
			"bar": sqltypes.Int64BindVariable(2),
		},
	}, {
		// bv collision
		in:      "select * from t where foo = :bar and bar = 12",
		outstmt: "select * from t where foo = :bar and bar = :bar1",
		outbv: map[string]*querypb.BindVariable{
			"bar1": sqltypes.Int64BindVariable(12),
		},
	}, {
		// val reuse
		in:      "select * from t where foo = 1 and bar = 1",
		outstmt: "select * from t where foo = :foo and bar = :foo",
		outbv: map[string]*querypb.BindVariable{
			"foo": sqltypes.Int64BindVariable(1),
		},
	}, {
		// ints and strings are different
		in:      "select * from t where foo = 1 and bar = '1'",
		outstmt: "select * from t where foo = :foo and bar = :bar",
		outbv: map[string]*querypb.BindVariable{
			"foo": sqltypes.Int64BindVariable(1),
			"bar": sqltypes.StringBindVariable("1"),
		},
	}, {
		// val should not be reused for non-select statements
		in:      "insert into a values(1, 1)",
		outstmt: "insert into a values (:bv1, :bv2)",
		outbv: map[string]*querypb.BindVariable{
			"bv1": sqltypes.Int64BindVariable(1),
			"bv2": sqltypes.Int64BindVariable(1),
		},
	}, {
		// val should be reused only in subqueries of DMLs
		in:      "update a set v1=(select 5 from t), v2=5, v3=(select 5 from t), v4=5",
		outstmt: "update a set v1 = (select :bv1 from t), v2 = :bv1, v3 = (select :bv1 from t), v4 = :bv1",
		outbv: map[string]*querypb.BindVariable{
			"bv1": sqltypes.Int64BindVariable(5),
		},
	}, {
		// list vars should work for DMLs also
		in:      "update a set v1=5 where v2 in (1, 4, 5)",
		outstmt: "update a set v1 = :v1 where v2 in ::bv1",
		outbv: map[string]*querypb.BindVariable{
			"v1":  sqltypes.Int64BindVariable(5),
			"bv1": sqltypes.TestBindVariable([]any{1, 4, 5}),
		},
	}, {
		// Hex number values should work for selects
		in:      "select * from t where foo = 0x1234",
		outstmt: "select * from t where foo = :foo",
		outbv: map[string]*querypb.BindVariable{
			"foo": sqltypes.HexNumBindVariable([]byte("0x1234")),
		},
	}, {
		// Hex encoded string values should work for selects
		in:      "select * from t where foo = x'7b7d'",
		outstmt: "select * from t where foo = :foo",
		outbv: map[string]*querypb.BindVariable{
			"foo": sqltypes.HexValBindVariable([]byte("x'7b7d'")),
		},
	}, {
		// Ensure that hex notation bind vars work with collation based conversions
		in:      "select convert(x'7b7d' using utf8mb4) from dual",
		outstmt: "select convert(:bv1 using utf8mb4) from dual",
		outbv: map[string]*querypb.BindVariable{
			"bv1": sqltypes.HexValBindVariable([]byte("x'7b7d'")),
		},
	}, {
		// Hex number values should work for DMLs
		in:      "update a set foo = 0x12",
		outstmt: "update a set foo = :foo",
		outbv: map[string]*querypb.BindVariable{
			"foo": sqltypes.HexNumBindVariable([]byte("0x12")),
		},
	}, {
		// Bin values work fine
		in:      "select * from t where foo = b'11'",
		outstmt: "select * from t where foo = :foo",
		outbv: map[string]*querypb.BindVariable{
			"foo": sqltypes.HexNumBindVariable([]byte("0x3")),
		},
	}, {
		// Bin values in DMLs are expected as hex number bind variables too
		in:      "update a set v1 = b'11'",
		outstmt: "update a set v1 = :v1",
		outbv: map[string]*querypb.BindVariable{
			"v1": sqltypes.HexNumBindVariable([]byte("0x3")),
		},
	}, {
		// ORDER BY column_position
		in:      "select a, b from t order by 1 asc",
		outstmt: "select a, b from t order by 1 asc",
		outbv:   map[string]*querypb.BindVariable{},
	}, {
		// GROUP BY column_position
		in:      "select a, b from t group by 1",
		outstmt: "select a, b from t group by 1",
		outbv:   map[string]*querypb.BindVariable{},
	}, {
		// ORDER BY with literal inside complex expression
		in:      "select a, b from t order by field(a,1,2,3) asc",
		outstmt: "select a, b from t order by field(a, :bv1, :bv2, :bv3) asc",
		outbv: map[string]*querypb.BindVariable{
			"bv1": sqltypes.Int64BindVariable(1),
			"bv2": sqltypes.Int64BindVariable(2),
			"bv3": sqltypes.Int64BindVariable(3),
		},
	}, {
		// ORDER BY column name
		in:      "select a, b from t order by c asc",
		outstmt: "select a, b from t order by c asc",
		outbv:   map[string]*querypb.BindVariable{},
	}, {
		// Values up to len 256 will reuse.
		in:      fmt.Sprintf("select * from t where foo = '%256s' and bar = '%256s'", "a", "a"),
		outstmt: "select * from t where foo = :foo and bar = :foo",
		outbv: map[string]*querypb.BindVariable{
			"foo": sqltypes.StringBindVariable(fmt.Sprintf("%256s", "a")),
		},
	}, {
		// Values greater than len 256 will not reuse.
		in:      fmt.Sprintf("select * from t where foo = '%257s' and bar = '%257s'", "b", "b"),
		outstmt: "select * from t where foo = :foo and bar = :bar",
		outbv: map[string]*querypb.BindVariable{
			"foo": sqltypes.StringBindVariable(fmt.Sprintf("%257s", "b")),
			"bar": sqltypes.StringBindVariable(fmt.Sprintf("%257s", "b")),
		},
	}, {
		// bad int: the literal is outside the int64 range and is expected to stay in place
		in:      "select * from t where v1 = 12345678901234567890",
		outstmt: "select * from t where v1 = 12345678901234567890",
		outbv:   map[string]*querypb.BindVariable{},
	}, {
		// comparison with no vals
		in:      "select * from t where v1 = v2",
		outstmt: "select * from t where v1 = v2",
		outbv:   map[string]*querypb.BindVariable{},
	}, {
		// IN clause with existing bv
		in:      "select * from t where v1 in ::list",
		outstmt: "select * from t where v1 in ::list",
		outbv:   map[string]*querypb.BindVariable{},
	}, {
		// IN clause with non-val values
		in:      "select * from t where v1 in (1, a)",
		outstmt: "select * from t where v1 in (:bv1, a)",
		outbv: map[string]*querypb.BindVariable{
			"bv1": sqltypes.Int64BindVariable(1),
		},
	}, {
		// IN clause with vals
		in:      "select * from t where v1 in (1, '2')",
		outstmt: "select * from t where v1 in ::bv1",
		outbv: map[string]*querypb.BindVariable{
			"bv1": sqltypes.TestBindVariable([]any{1, "2"}),
		},
	}, {
		// EXPLAIN queries
		in:      "explain select * from t where v1 in (1, '2')",
		outstmt: "explain select * from t where v1 in ::bv1",
		outbv: map[string]*querypb.BindVariable{
			"bv1": sqltypes.TestBindVariable([]any{1, "2"}),
		},
	}, {
		// NOT IN clause
		in:      "select * from t where v1 not in (1, '2')",
		outstmt: "select * from t where v1 not in ::bv1",
		outbv: map[string]*querypb.BindVariable{
			"bv1": sqltypes.TestBindVariable([]any{1, "2"}),
		},
	}, {
		// Do not normalize cast/convert types
		in:      `select CAST("test" AS CHAR(60))`,
		outstmt: `select cast(:bv1 as CHAR(60)) from dual`,
		outbv: map[string]*querypb.BindVariable{
			"bv1": sqltypes.StringBindVariable("test"),
		},
	}, {
		// insert syntax
		in:      "insert into a (v1, v2, v3) values (1, '2', 3)",
		outstmt: "insert into a(v1, v2, v3) values (:bv1, :bv2, :bv3)",
		outbv: map[string]*querypb.BindVariable{
			"bv1": sqltypes.Int64BindVariable(1),
			"bv2": sqltypes.StringBindVariable("2"),
			"bv3": sqltypes.Int64BindVariable(3),
		},
	}, {
		// BitVal should also be normalized
		in:      `select b'1', 0b01, b'1010', 0b1111111`,
		outstmt: `select :bv1, :bv2, :bv3, :bv4 from dual`,
		outbv: map[string]*querypb.BindVariable{
			"bv1": sqltypes.HexNumBindVariable([]byte("0x1")),
			"bv2": sqltypes.HexNumBindVariable([]byte("0x1")),
			"bv3": sqltypes.HexNumBindVariable([]byte("0xa")),
			"bv4": sqltypes.HexNumBindVariable([]byte("0x7f")),
		},
	}, {
		// DateVal should also be normalized
		in:      `select date'2022-08-06'`,
		outstmt: `select :bv1 from dual`,
		outbv: map[string]*querypb.BindVariable{
			"bv1": sqltypes.ValueBindVariable(sqltypes.MakeTrusted(sqltypes.Date, []byte("2022-08-06"))),
		},
	}, {
		// TimeVal should also be normalized
		in:      `select time'17:05:12'`,
		outstmt: `select :bv1 from dual`,
		outbv: map[string]*querypb.BindVariable{
			"bv1": sqltypes.ValueBindVariable(sqltypes.MakeTrusted(sqltypes.Time, []byte("17:05:12"))),
		},
	}, {
		// TimestampVal should also be normalized
		in:      `select timestamp'2022-08-06 17:05:12'`,
		outstmt: `select :bv1 from dual`,
		outbv: map[string]*querypb.BindVariable{
			"bv1": sqltypes.ValueBindVariable(sqltypes.MakeTrusted(sqltypes.Datetime, []byte("2022-08-06 17:05:12"))),
		},
	}, {
		// EXPLAIN with a table-qualified column comparison and a "limit offset, count" clause
		in:      `explain select comms_by_companies.* from comms_by_companies where comms_by_companies.id = 'rjve634shXzaavKHbAH16ql6OrxJ' limit 1,1`,
		outstmt: `explain select comms_by_companies.* from comms_by_companies where comms_by_companies.id = :comms_by_companies_id limit :bv1, :bv2`,
		outbv: map[string]*querypb.BindVariable{
			"bv1":                   sqltypes.Int64BindVariable(1),
			"bv2":                   sqltypes.Int64BindVariable(1),
			"comms_by_companies_id": sqltypes.StringBindVariable("rjve634shXzaavKHbAH16ql6OrxJ"),
		},
	}, {
		// Int leading with zero should also be normalized
		in:      `select * from t where zipcode = 01001900`,
		outstmt: `select * from t where zipcode = :zipcode`,
		outbv: map[string]*querypb.BindVariable{
			"zipcode": sqltypes.ValueBindVariable(sqltypes.MakeTrusted(sqltypes.Int64, []byte("01001900"))),
		},
	}, {
		// literals in limit and offset should not reuse bindvars
		in:      `select * from t where id = 10 limit 10 offset 10`,
		outstmt: `select * from t where id = :id limit :bv1, :bv2`,
		outbv: map[string]*querypb.BindVariable{
			"bv1": sqltypes.Int64BindVariable(10),
			"bv2": sqltypes.Int64BindVariable(10),
			"id":  sqltypes.Int64BindVariable(10),
		},
	}, {
		// we don't want to replace literals on the select expressions of a derived table
		// these expressions can be referenced from the outside,
		// and changing them to bindvars can change the meaning of the query
		// example of problematic query: select tmp.`1` from (select 1) as tmp
		in:      `select * from (select 12) as t`,
		outstmt: `select * from (select 12 from dual) as t`,
		outbv:   map[string]*querypb.BindVariable{},
	}}
	for _, tc := range testcases {
		// Each case is a subtest named after its input SQL.
		t.Run(tc.in, func(t *testing.T) {
			stmt, err := Parse(tc.in)
			require.NoError(t, err)
			// Bind variables already referenced by the input are passed as known names.
			known := GetBindvars(stmt)
			// Normalize fills this map with the bind variables it extracts.
			bv := make(map[string]*querypb.BindVariable)
			require.NoError(t, Normalize(stmt, NewReservedVars(prefix, known), bv))
			// assert (not require) so that both mismatches are reported for a failing case.
			assert.Equal(t, tc.outstmt, String(stmt))
			assert.Equal(t, tc.outbv, bv)
		})
	}
}
   352  
   353  func TestNormalizeInvalidDates(t *testing.T) {
   354  	testcases := []struct {
   355  		in  string
   356  		err error
   357  	}{{
   358  		in:  "select date'foo'",
   359  		err: vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.WrongValue, "incorrect DATE value: '%s'", "foo"),
   360  	}, {
   361  		in:  "select time'foo'",
   362  		err: vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.WrongValue, "incorrect TIME value: '%s'", "foo"),
   363  	}, {
   364  		in:  "select timestamp'foo'",
   365  		err: vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.WrongValue, "incorrect DATETIME value: '%s'", "foo"),
   366  	}}
   367  	for _, tc := range testcases {
   368  		t.Run(tc.in, func(t *testing.T) {
   369  			stmt, err := Parse(tc.in)
   370  			require.NoError(t, err)
   371  			known := GetBindvars(stmt)
   372  			bv := make(map[string]*querypb.BindVariable)
   373  			require.EqualError(t, Normalize(stmt, NewReservedVars("bv", known), bv), tc.err.Error())
   374  		})
   375  	}
   376  }
   377  
   378  func TestNormalizeValidSQL(t *testing.T) {
   379  	for _, tcase := range validSQL {
   380  		t.Run(tcase.input, func(t *testing.T) {
   381  			if tcase.partialDDL || tcase.ignoreNormalizerTest {
   382  				return
   383  			}
   384  			tree, err := Parse(tcase.input)
   385  			require.NoError(t, err, tcase.input)
   386  			// Skip the test for the queries that do not run the normalizer
   387  			if !CanNormalize(tree) {
   388  				return
   389  			}
   390  			bv := make(map[string]*querypb.BindVariable)
   391  			known := make(BindVars)
   392  			err = Normalize(tree, NewReservedVars("vtg", known), bv)
   393  			require.NoError(t, err)
   394  			normalizerOutput := String(tree)
   395  			if normalizerOutput == "otheradmin" || normalizerOutput == "otherread" {
   396  				return
   397  			}
   398  			_, err = Parse(normalizerOutput)
   399  			require.NoError(t, err, normalizerOutput)
   400  		})
   401  	}
   402  }
   403  
   404  func TestGetBindVars(t *testing.T) {
   405  	stmt, err := Parse("select * from t where :v1 = :v2 and :v2 = :v3 and :v4 in ::v5")
   406  	if err != nil {
   407  		t.Fatal(err)
   408  	}
   409  	got := GetBindvars(stmt)
   410  	want := map[string]struct{}{
   411  		"v1": {},
   412  		"v2": {},
   413  		"v3": {},
   414  		"v4": {},
   415  		"v5": {},
   416  	}
   417  	if !reflect.DeepEqual(got, want) {
   418  		t.Errorf("GetBindVars: %v, want: %v", got, want)
   419  	}
   420  }
   421  
   422  /*
   423  Skipping ColName, TableName:
   424  BenchmarkNormalize-8     1000000              2205 ns/op             821 B/op         27 allocs/op
   425  Prior to skip:
   426  BenchmarkNormalize-8      500000              3620 ns/op            1461 B/op         55 allocs/op
   427  */
   428  func BenchmarkNormalize(b *testing.B) {
   429  	sql := "select 'abcd', 20, 30.0, eid from a where 1=eid and name='3'"
   430  	ast, reservedVars, err := Parse2(sql)
   431  	if err != nil {
   432  		b.Fatal(err)
   433  	}
   434  	for i := 0; i < b.N; i++ {
   435  		require.NoError(b, Normalize(ast, NewReservedVars("", reservedVars), map[string]*querypb.BindVariable{}))
   436  	}
   437  }
   438  
   439  func BenchmarkNormalizeTraces(b *testing.B) {
   440  	for _, trace := range []string{"django_queries.txt", "lobsters.sql.gz"} {
   441  		b.Run(trace, func(b *testing.B) {
   442  			queries := loadQueries(b, trace)
   443  			if len(queries) > 10000 {
   444  				queries = queries[:10000]
   445  			}
   446  
   447  			parsed := make([]Statement, 0, len(queries))
   448  			reservedVars := make([]BindVars, 0, len(queries))
   449  			for _, q := range queries {
   450  				pp, kb, err := Parse2(q)
   451  				if err != nil {
   452  					b.Fatal(err)
   453  				}
   454  				parsed = append(parsed, pp)
   455  				reservedVars = append(reservedVars, kb)
   456  			}
   457  
   458  			b.ResetTimer()
   459  			b.ReportAllocs()
   460  
   461  			for i := 0; i < b.N; i++ {
   462  				for i, query := range parsed {
   463  					_ = Normalize(query, NewReservedVars("", reservedVars[i]), map[string]*querypb.BindVariable{})
   464  				}
   465  			}
   466  		})
   467  	}
   468  }
   469  
   470  func BenchmarkNormalizeVTGate(b *testing.B) {
   471  	const keyspace = "main_keyspace"
   472  
   473  	queries := loadQueries(b, "lobsters.sql.gz")
   474  	if len(queries) > 10000 {
   475  		queries = queries[:10000]
   476  	}
   477  
   478  	b.ResetTimer()
   479  	b.ReportAllocs()
   480  
   481  	for i := 0; i < b.N; i++ {
   482  		for _, sql := range queries {
   483  			stmt, reservedVars, err := Parse2(sql)
   484  			if err != nil {
   485  				b.Fatal(err)
   486  			}
   487  
   488  			query := sql
   489  			statement := stmt
   490  			bindVarNeeds := &BindVarNeeds{}
   491  			bindVars := make(map[string]*querypb.BindVariable)
   492  			_ = IgnoreMaxMaxMemoryRowsDirective(stmt)
   493  
   494  			// Normalize if possible and retry.
   495  			if CanNormalize(stmt) || MustRewriteAST(stmt, false) {
   496  				result, err := PrepareAST(
   497  					stmt,
   498  					NewReservedVars("vtg", reservedVars),
   499  					bindVars,
   500  					true,
   501  					keyspace,
   502  					SQLSelectLimitUnset,
   503  					"",
   504  					nil, /*sysvars*/
   505  					nil, /*views*/
   506  				)
   507  				if err != nil {
   508  					b.Fatal(err)
   509  				}
   510  				statement = result.AST
   511  				bindVarNeeds = result.BindVarNeeds
   512  				query = String(statement)
   513  			}
   514  
   515  			_ = query
   516  			_ = statement
   517  			_ = bindVarNeeds
   518  		}
   519  	}
   520  }
   521  
   522  func randtmpl(template string) string {
   523  	const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
   524  	const numberBytes = "0123456789"
   525  
   526  	result := []byte(template)
   527  	for i, c := range result {
   528  		switch c {
   529  		case '#':
   530  			result[i] = numberBytes[rand.Intn(len(numberBytes))]
   531  		case '@':
   532  			result[i] = letterBytes[rand.Intn(len(letterBytes))]
   533  		}
   534  	}
   535  	return string(result)
   536  }
   537  
   538  func randString(n int) string {
   539  	const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
   540  	b := make([]byte, n)
   541  	for i := range b {
   542  		b[i] = letterBytes[rand.Intn(len(letterBytes))]
   543  	}
   544  	return string(b)
   545  }
   546  
   547  func BenchmarkNormalizeTPCCBinds(b *testing.B) {
   548  	query := `INSERT IGNORE INTO customer0
   549  (c_id, c_d_id, c_w_id, c_first, c_middle, c_last, c_street_1, c_street_2, c_city, c_state, c_zip, c_phone, c_since, c_credit, c_credit_lim, c_discount, c_balance, c_ytd_payment, c_payment_cnt, c_delivery_cnt, c_data)
   550  values
   551  (:c_id, :c_d_id, :c_w_id, :c_first, :c_middle, :c_last, :c_street_1, :c_street_2, :c_city, :c_state, :c_zip, :c_phone, :c_since, :c_credit, :c_credit_lim, :c_discount, :c_balance, :c_ytd_payment, :c_payment_cnt, :c_delivery_cnt, :c_data)`
   552  	benchmarkNormalization(b, []string{query})
   553  }
   554  
   555  func BenchmarkNormalizeTPCCInsert(b *testing.B) {
   556  	generateInsert := func(rows int) string {
   557  		var query bytes.Buffer
   558  		query.WriteString("INSERT IGNORE INTO customer0 (c_id, c_d_id, c_w_id, c_first, c_middle, c_last, c_street_1, c_street_2, c_city, c_state, c_zip, c_phone, c_since, c_credit, c_credit_lim, c_discount, c_balance, c_ytd_payment, c_payment_cnt, c_delivery_cnt, c_data) values ")
   559  		for i := 0; i < rows; i++ {
   560  			fmt.Fprintf(&query, "(%d, %d, %d, '%s','OE','%s','%s', '%s', '%s', '%s', '%s','%s',NOW(),'%s',50000,%f,-10,10,1,0,'%s' )",
   561  				rand.Int(), rand.Int(), rand.Int(),
   562  				"first-"+randString(rand.Intn(10)),
   563  				randtmpl("last-@@@@"),
   564  				randtmpl("street1-@@@@@@@@@@@@"),
   565  				randtmpl("street2-@@@@@@@@@@@@"),
   566  				randtmpl("city-@@@@@@@@@@@@"),
   567  				randtmpl("@@"), randtmpl("zip-#####"),
   568  				randtmpl("################"),
   569  				"GC", rand.Float64(), randString(300+rand.Intn(200)),
   570  			)
   571  			if i < rows-1 {
   572  				query.WriteString(", ")
   573  			}
   574  		}
   575  		return query.String()
   576  	}
   577  
   578  	var queries []string
   579  
   580  	for i := 0; i < 1024; i++ {
   581  		queries = append(queries, generateInsert(4))
   582  	}
   583  
   584  	benchmarkNormalization(b, queries)
   585  }
   586  
// BenchmarkNormalizeTPCC runs benchmarkNormalization on queries generated from
// a list of TPC-C style SQL templates. Every "%s", "%d" and "%f" placeholder in
// a template is substituted with a value, and each template is expanded 128
// times.
func BenchmarkNormalizeTPCC(b *testing.B) {
	// The templates use printf-like verbs purely as markers; they are filled in
	// by the regexp replacement below, not by fmt.
	templates := []string{
		`SELECT c_discount, c_last, c_credit, w_tax
FROM customer%d AS c
JOIN warehouse%d AS w ON c_w_id=w_id
WHERE w_id = %d
AND c_d_id = %d
AND c_id = %d`,
		`SELECT d_next_o_id, d_tax 
FROM district%d 
WHERE d_w_id = %d 
AND d_id = %d FOR UPDATE`,
		`UPDATE district%d
SET d_next_o_id = %d
WHERE d_id = %d AND d_w_id= %d`,
		`INSERT INTO orders%d
(o_id, o_d_id, o_w_id, o_c_id,  o_entry_d, o_ol_cnt, o_all_local)
VALUES (%d,%d,%d,%d,NOW(),%d,%d)`,
		`INSERT INTO new_orders%d (no_o_id, no_d_id, no_w_id)
VALUES (%d,%d,%d)`,
		`SELECT i_price, i_name, i_data 
FROM item%d
WHERE i_id = %d`,
		`SELECT s_quantity, s_data, s_dist_%s s_dist 
FROM stock%d  
WHERE s_i_id = %d AND s_w_id= %d FOR UPDATE`,
		`UPDATE stock%d
SET s_quantity = %d
WHERE s_i_id = %d 
AND s_w_id= %d`,
		`INSERT INTO order_line%d
(ol_o_id, ol_d_id, ol_w_id, ol_number, ol_i_id, ol_supply_w_id, ol_quantity, ol_amount, ol_dist_info)
VALUES (%d,%d,%d,%d,%d,%d,%d,%d,'%s')`,
		`UPDATE warehouse%d
SET w_ytd = w_ytd + %d 
WHERE w_id = %d`,
		`SELECT w_street_1, w_street_2, w_city, w_state, w_zip, w_name
FROM warehouse%d
WHERE w_id = %d`,
		`UPDATE district%d 
SET d_ytd = d_ytd + %d 
WHERE d_w_id = %d 
AND d_id= %d`,
		`SELECT d_street_1, d_street_2, d_city, d_state, d_zip, d_name 
FROM district%d
WHERE d_w_id = %d 
AND d_id = %d`,
		`SELECT count(c_id) namecnt
FROM customer%d
WHERE c_w_id = %d 
AND c_d_id= %d
AND c_last='%s'`,
		`SELECT c_first, c_middle, c_last, c_street_1,
c_street_2, c_city, c_state, c_zip, c_phone,
c_credit, c_credit_lim, c_discount, c_balance, c_ytd_payment, c_since
FROM customer%d
WHERE c_w_id = %d 
AND c_d_id= %d
AND c_id=%d FOR UPDATE`,
		`SELECT c_data
FROM customer%d
WHERE c_w_id = %d 
AND c_d_id=%d
AND c_id= %d`,
		`UPDATE customer%d
SET c_balance=%f, c_ytd_payment=%f, c_data='%s'
WHERE c_w_id = %d 
AND c_d_id=%d
AND c_id=%d`,
		`UPDATE customer%d
SET c_balance=%f, c_ytd_payment=%f
WHERE c_w_id = %d 
AND c_d_id=%d
AND c_id=%d`,
		`INSERT INTO history%d
(h_c_d_id, h_c_w_id, h_c_id, h_d_id,  h_w_id, h_date, h_amount, h_data)
VALUES (%d,%d,%d,%d,%d,NOW(),%d,'%s')`,
		`SELECT count(c_id) namecnt
FROM customer%d
WHERE c_w_id = %d 
AND c_d_id= %d
AND c_last='%s'`,
		`SELECT c_balance, c_first, c_middle, c_id
FROM customer%d
WHERE c_w_id = %d 
AND c_d_id= %d
AND c_last='%s' ORDER BY c_first`,
		`SELECT c_balance, c_first, c_middle, c_last
FROM customer%d
WHERE c_w_id = %d 
AND c_d_id=%d
AND c_id=%d`,
		`SELECT o_id, o_carrier_id, o_entry_d
FROM orders%d 
WHERE o_w_id = %d 
AND o_d_id = %d 
AND o_c_id = %d 
ORDER BY o_id DESC`,
		`SELECT ol_i_id, ol_supply_w_id, ol_quantity, ol_amount, ol_delivery_d
FROM order_line%d WHERE ol_w_id = %d AND ol_d_id = %d  AND ol_o_id = %d`,
		`SELECT no_o_id
FROM new_orders%d 
WHERE no_d_id = %d 
AND no_w_id = %d 
ORDER BY no_o_id ASC LIMIT 1 FOR UPDATE`,
		`DELETE FROM new_orders%d
WHERE no_o_id = %d 
AND no_d_id = %d  
AND no_w_id = %d`,
		`SELECT o_c_id
FROM orders%d 
WHERE o_id = %d 
AND o_d_id = %d 
AND o_w_id = %d`,
		`UPDATE orders%d 
SET o_carrier_id = %d
WHERE o_id = %d 
AND o_d_id = %d 
AND o_w_id = %d`,
		`UPDATE order_line%d 
SET ol_delivery_d = NOW()
WHERE ol_o_id = %d 
AND ol_d_id = %d 
AND ol_w_id = %d`,
		`SELECT SUM(ol_amount) sm
FROM order_line%d 
WHERE ol_o_id = %d 
AND ol_d_id = %d 
AND ol_w_id = %d`,
		`UPDATE customer%d 
SET c_balance = c_balance + %f,
c_delivery_cnt = c_delivery_cnt + 1
WHERE c_id = %d 
AND c_d_id = %d 
AND c_w_id = %d`,
		`SELECT d_next_o_id 
FROM district%d
WHERE d_id = %d AND d_w_id= %d`,
		`SELECT COUNT(DISTINCT(s.s_i_id))
FROM stock%d AS s
JOIN order_line%d AS ol ON ol.ol_w_id=s.s_w_id AND ol.ol_i_id=s.s_i_id			
WHERE ol.ol_w_id = %d 
AND ol.ol_d_id = %d
AND ol.ol_o_id < %d 
AND ol.ol_o_id >= %d
AND s.s_w_id= %d
AND s.s_quantity < %d `,
		`SELECT DISTINCT ol_i_id FROM order_line%d
WHERE ol_w_id = %d AND ol_d_id = %d
AND ol_o_id < %d AND ol_o_id >= %d`,
		`SELECT count(*) FROM stock%d
WHERE s_w_id = %d AND s_i_id = %d
AND s_quantity < %d`,
		`SELECT min(no_o_id) mo
FROM new_orders%d 
WHERE no_w_id = %d AND no_d_id = %d`,
		`SELECT o_id FROM orders%d o, (SELECT o_c_id,o_w_id,o_d_id,count(distinct o_id) FROM orders%d WHERE o_w_id=%d AND o_d_id=%d AND o_id > 2100 AND o_id < %d GROUP BY o_c_id,o_d_id,o_w_id having count( distinct o_id) > 1 limit 1) t WHERE t.o_w_id=o.o_w_id and t.o_d_id=o.o_d_id and t.o_c_id=o.o_c_id limit 1 `,
		`DELETE FROM order_line%d where ol_w_id=%d AND ol_d_id=%d AND ol_o_id=%d`,
		`DELETE FROM orders%d where o_w_id=%d AND o_d_id=%d and o_id=%d`,
		`DELETE FROM history%d where h_w_id=%d AND h_d_id=%d LIMIT 10`,
	}

	// Matches a '%' followed by one word character, i.e. each placeholder.
	re := regexp.MustCompile(`%\w`)
	// repl maps a matched placeholder to its substitute: a fixed string for
	// "%s", a random non-negative int for "%d" and a random float in [0, 1)
	// with 8 decimals for "%f". Any other match is a mistake in a template,
	// so it panics.
	repl := func(m string) string {
		switch m {
		case "%s":
			return "RANDOM_STRING"
		case "%d":
			return strconv.Itoa(rand.Int())
		case "%f":
			return strconv.FormatFloat(rand.Float64(), 'f', 8, 64)
		default:
			panic(m)
		}
	}

	var queries []string

	// Expand every template 128 times; the random "%d"/"%f" values are drawn
	// anew for each expansion.
	for _, tmpl := range templates {
		for i := 0; i < 128; i++ {
			queries = append(queries, re.ReplaceAllStringFunc(tmpl, repl))
		}
	}

	benchmarkNormalization(b, queries)
}
   773  
   774  func benchmarkNormalization(b *testing.B, sqls []string) {
   775  	b.Helper()
   776  	b.ReportAllocs()
   777  	b.ResetTimer()
   778  	for i := 0; i < b.N; i++ {
   779  		for _, sql := range sqls {
   780  			stmt, reserved, err := Parse2(sql)
   781  			if err != nil {
   782  				b.Fatalf("%v: %q", err, sql)
   783  			}
   784  
   785  			reservedVars := NewReservedVars("vtg", reserved)
   786  			_, err = PrepareAST(
   787  				stmt,
   788  				reservedVars,
   789  				make(map[string]*querypb.BindVariable),
   790  				true,
   791  				"keyspace0",
   792  				SQLSelectLimitUnset,
   793  				"",
   794  				nil,
   795  				nil,
   796  			)
   797  			if err != nil {
   798  				b.Fatal(err)
   799  			}
   800  		}
   801  	}
   802  }