vitess.io/vitess@v0.16.2/go/vt/sqlparser/parsed_query_test.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package sqlparser
    18  
    19  import (
    20  	"reflect"
    21  	"testing"
    22  
    23  	"vitess.io/vitess/go/sqltypes"
    24  	querypb "vitess.io/vitess/go/vt/proto/query"
    25  
    26  	"github.com/stretchr/testify/assert"
    27  )
    28  
    29  func TestNewParsedQuery(t *testing.T) {
    30  	stmt, err := Parse("select * from a where id =:id")
    31  	if err != nil {
    32  		t.Error(err)
    33  		return
    34  	}
    35  	pq := NewParsedQuery(stmt)
    36  	want := &ParsedQuery{
    37  		Query:         "select * from a where id = :id",
    38  		bindLocations: []bindLocation{{offset: 27, length: 3}},
    39  	}
    40  	if !reflect.DeepEqual(pq, want) {
    41  		t.Errorf("GenerateParsedQuery: %+v, want %+v", pq, want)
    42  	}
    43  }
    44  
    45  func TestGenerateQuery(t *testing.T) {
    46  	tcases := []struct {
    47  		desc     string
    48  		query    string
    49  		bindVars map[string]*querypb.BindVariable
    50  		extras   map[string]Encodable
    51  		output   string
    52  	}{
    53  		{
    54  			desc:  "no substitutions",
    55  			query: "select * from a where id = 2",
    56  			bindVars: map[string]*querypb.BindVariable{
    57  				"id": sqltypes.Int64BindVariable(1),
    58  			},
    59  			output: "select * from a where id = 2",
    60  		}, {
    61  			desc:  "missing bind var",
    62  			query: "select * from a where id1 = :id1 and id2 = :id2",
    63  			bindVars: map[string]*querypb.BindVariable{
    64  				"id1": sqltypes.Int64BindVariable(1),
    65  			},
    66  			output: "missing bind var id2",
    67  		}, {
    68  			desc:  "simple bindvar substitution",
    69  			query: "select * from a where id1 = :id1 and id2 = :id2",
    70  			bindVars: map[string]*querypb.BindVariable{
    71  				"id1": sqltypes.Int64BindVariable(1),
    72  				"id2": sqltypes.NullBindVariable,
    73  			},
    74  			output: "select * from a where id1 = 1 and id2 = null",
    75  		}, {
    76  			desc:  "tuple *querypb.BindVariable",
    77  			query: "select * from a where id in ::vals",
    78  			bindVars: map[string]*querypb.BindVariable{
    79  				"vals": sqltypes.TestBindVariable([]any{1, "aa"}),
    80  			},
    81  			output: "select * from a where id in (1, 'aa')",
    82  		}, {
    83  			desc:  "list bind vars 0 arguments",
    84  			query: "select * from a where id in ::vals",
    85  			bindVars: map[string]*querypb.BindVariable{
    86  				"vals": sqltypes.TestBindVariable([]any{}),
    87  			},
    88  			output: "empty list supplied for vals",
    89  		}, {
    90  			desc:  "non-list bind var supplied",
    91  			query: "select * from a where id in ::vals",
    92  			bindVars: map[string]*querypb.BindVariable{
    93  				"vals": sqltypes.Int64BindVariable(1),
    94  			},
    95  			output: "unexpected list arg type (INT64) for key vals",
    96  		}, {
    97  			desc:  "list bind var for non-list",
    98  			query: "select * from a where id = :vals",
    99  			bindVars: map[string]*querypb.BindVariable{
   100  				"vals": sqltypes.TestBindVariable([]any{1}),
   101  			},
   102  			output: "unexpected arg type (TUPLE) for non-list key vals",
   103  		}, {
   104  			desc:  "single column tuple equality",
   105  			query: "select * from a where b = :equality",
   106  			extras: map[string]Encodable{
   107  				"equality": &TupleEqualityList{
   108  					Columns: []IdentifierCI{NewIdentifierCI("pk")},
   109  					Rows: [][]sqltypes.Value{
   110  						{sqltypes.NewInt64(1)},
   111  						{sqltypes.NewVarBinary("aa")},
   112  					},
   113  				},
   114  			},
   115  			output: "select * from a where b = pk in (1, 'aa')",
   116  		}, {
   117  			desc:  "multi column tuple equality",
   118  			query: "select * from a where b = :equality",
   119  			extras: map[string]Encodable{
   120  				"equality": &TupleEqualityList{
   121  					Columns: []IdentifierCI{NewIdentifierCI("pk1"), NewIdentifierCI("pk2")},
   122  					Rows: [][]sqltypes.Value{
   123  						{
   124  							sqltypes.NewInt64(1),
   125  							sqltypes.NewVarBinary("aa"),
   126  						},
   127  						{
   128  							sqltypes.NewInt64(2),
   129  							sqltypes.NewVarBinary("bb"),
   130  						},
   131  					},
   132  				},
   133  			},
   134  			output: "select * from a where b = (pk1 = 1 and pk2 = 'aa') or (pk1 = 2 and pk2 = 'bb')",
   135  		},
   136  	}
   137  
   138  	for _, tcase := range tcases {
   139  		tree, err := Parse(tcase.query)
   140  		if err != nil {
   141  			t.Errorf("parse failed for %s: %v", tcase.desc, err)
   142  			continue
   143  		}
   144  		buf := NewTrackedBuffer(nil)
   145  		buf.Myprintf("%v", tree)
   146  		pq := buf.ParsedQuery()
   147  		bytes, err := pq.GenerateQuery(tcase.bindVars, tcase.extras)
   148  		if err != nil {
   149  			assert.Equal(t, tcase.output, err.Error())
   150  		} else {
   151  			assert.Equal(t, tcase.output, string(bytes))
   152  		}
   153  	}
   154  }
   155  
   156  func TestParseAndBind(t *testing.T) {
   157  	testcases := []struct {
   158  		in    string
   159  		binds []*querypb.BindVariable
   160  		out   string
   161  	}{
   162  		{
   163  			in:  "select * from tbl",
   164  			out: "select * from tbl",
   165  		}, {
   166  			in:  "select * from tbl where b=4 or a=3",
   167  			out: "select * from tbl where b=4 or a=3",
   168  		}, {
   169  			in:  "select * from tbl where b = 4 or a = 3",
   170  			out: "select * from tbl where b = 4 or a = 3",
   171  		}, {
   172  			in:    "select * from tbl where name=%a",
   173  			binds: []*querypb.BindVariable{sqltypes.StringBindVariable("xyz")},
   174  			out:   "select * from tbl where name='xyz'",
   175  		}, {
   176  			in:    "select * from tbl where c=%a",
   177  			binds: []*querypb.BindVariable{sqltypes.Int64BindVariable(17)},
   178  			out:   "select * from tbl where c=17",
   179  		}, {
   180  			in:    "select * from tbl where name=%a and c=%a",
   181  			binds: []*querypb.BindVariable{sqltypes.StringBindVariable("xyz"), sqltypes.Int64BindVariable(17)},
   182  			out:   "select * from tbl where name='xyz' and c=17",
   183  		}, {
   184  			in:    "select * from tbl where name=%a",
   185  			binds: []*querypb.BindVariable{sqltypes.StringBindVariable("it's")},
   186  			out:   "select * from tbl where name='it\\'s'",
   187  		}, {
   188  			in:    "where name=%a",
   189  			binds: []*querypb.BindVariable{sqltypes.StringBindVariable("xyz")},
   190  			out:   "where name='xyz'",
   191  		}, {
   192  			in:    "name=%a",
   193  			binds: []*querypb.BindVariable{sqltypes.StringBindVariable("xyz")},
   194  			out:   "name='xyz'",
   195  		},
   196  	}
   197  
   198  	for _, tc := range testcases {
   199  		t.Run(tc.in, func(t *testing.T) {
   200  			query, err := ParseAndBind(tc.in, tc.binds...)
   201  			assert.NoError(t, err)
   202  			assert.Equal(t, tc.out, query)
   203  		})
   204  	}
   205  }