vitess.io/vitess@v0.16.2/go/vt/sqlparser/parsed_query_test.go (about) 1 /* 2 Copyright 2019 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package sqlparser 18 19 import ( 20 "reflect" 21 "testing" 22 23 "vitess.io/vitess/go/sqltypes" 24 querypb "vitess.io/vitess/go/vt/proto/query" 25 26 "github.com/stretchr/testify/assert" 27 ) 28 29 func TestNewParsedQuery(t *testing.T) { 30 stmt, err := Parse("select * from a where id =:id") 31 if err != nil { 32 t.Error(err) 33 return 34 } 35 pq := NewParsedQuery(stmt) 36 want := &ParsedQuery{ 37 Query: "select * from a where id = :id", 38 bindLocations: []bindLocation{{offset: 27, length: 3}}, 39 } 40 if !reflect.DeepEqual(pq, want) { 41 t.Errorf("GenerateParsedQuery: %+v, want %+v", pq, want) 42 } 43 } 44 45 func TestGenerateQuery(t *testing.T) { 46 tcases := []struct { 47 desc string 48 query string 49 bindVars map[string]*querypb.BindVariable 50 extras map[string]Encodable 51 output string 52 }{ 53 { 54 desc: "no substitutions", 55 query: "select * from a where id = 2", 56 bindVars: map[string]*querypb.BindVariable{ 57 "id": sqltypes.Int64BindVariable(1), 58 }, 59 output: "select * from a where id = 2", 60 }, { 61 desc: "missing bind var", 62 query: "select * from a where id1 = :id1 and id2 = :id2", 63 bindVars: map[string]*querypb.BindVariable{ 64 "id1": sqltypes.Int64BindVariable(1), 65 }, 66 output: "missing bind var id2", 67 }, { 68 desc: "simple bindvar substitution", 69 query: "select * from a where id1 = :id1 and id2 = :id2", 70 bindVars: map[string]*querypb.BindVariable{ 71 "id1": sqltypes.Int64BindVariable(1), 72 "id2": sqltypes.NullBindVariable, 73 }, 74 output: "select * from a where id1 = 1 and id2 = null", 75 }, { 76 desc: "tuple *querypb.BindVariable", 77 query: "select * from a where id in ::vals", 78 bindVars: map[string]*querypb.BindVariable{ 79 "vals": sqltypes.TestBindVariable([]any{1, "aa"}), 80 }, 81 output: "select * from a where id in (1, 'aa')", 82 }, { 83 desc: "list bind vars 0 arguments", 84 query: "select * from a where id in ::vals", 85 bindVars: map[string]*querypb.BindVariable{ 86 "vals": sqltypes.TestBindVariable([]any{}), 87 }, 88 output: "empty list supplied for vals", 89 }, { 90 desc: "non-list bind var supplied", 91 query: "select * from a where id in ::vals", 92 bindVars: map[string]*querypb.BindVariable{ 93 "vals": sqltypes.Int64BindVariable(1), 94 }, 95 output: "unexpected list arg type (INT64) for key vals", 96 }, { 97 desc: "list bind var for non-list", 98 query: "select * from a where id = :vals", 99 bindVars: map[string]*querypb.BindVariable{ 100 "vals": sqltypes.TestBindVariable([]any{1}), 101 }, 102 output: "unexpected arg type (TUPLE) for non-list key vals", 103 }, { 104 desc: "single column tuple equality", 105 query: "select * from a where b = :equality", 106 extras: map[string]Encodable{ 107 "equality": &TupleEqualityList{ 108 Columns: []IdentifierCI{NewIdentifierCI("pk")}, 109 Rows: [][]sqltypes.Value{ 110 {sqltypes.NewInt64(1)}, 111 {sqltypes.NewVarBinary("aa")}, 112 }, 113 }, 114 }, 115 output: "select * from a where b = pk in (1, 'aa')", 116 }, { 117 desc: "multi column tuple equality", 118 query: "select * from a where b = :equality", 119 extras: map[string]Encodable{ 120 "equality": &TupleEqualityList{ 121 Columns: []IdentifierCI{NewIdentifierCI("pk1"), NewIdentifierCI("pk2")}, 122 Rows: [][]sqltypes.Value{ 123 { 124 sqltypes.NewInt64(1), 125 sqltypes.NewVarBinary("aa"), 126 }, 127 { 128 sqltypes.NewInt64(2), 129 sqltypes.NewVarBinary("bb"), 130 }, 131 }, 132 }, 133 }, 134 output: "select * from a where b = (pk1 = 1 and pk2 = 'aa') or (pk1 = 2 and pk2 = 'bb')", 135 }, 136 } 137 138 for _, tcase := range tcases { 139 tree, err := Parse(tcase.query) 140 if err != nil { 141 t.Errorf("parse failed for %s: %v", tcase.desc, err) 142 continue 143 } 144 buf := NewTrackedBuffer(nil) 145 buf.Myprintf("%v", tree) 146 pq := buf.ParsedQuery() 147 bytes, err := pq.GenerateQuery(tcase.bindVars, tcase.extras) 148 if err != nil { 149 assert.Equal(t, tcase.output, err.Error()) 150 } else { 151 assert.Equal(t, tcase.output, string(bytes)) 152 } 153 } 154 } 155 156 func TestParseAndBind(t *testing.T) { 157 testcases := []struct { 158 in string 159 binds []*querypb.BindVariable 160 out string 161 }{ 162 { 163 in: "select * from tbl", 164 out: "select * from tbl", 165 }, { 166 in: "select * from tbl where b=4 or a=3", 167 out: "select * from tbl where b=4 or a=3", 168 }, { 169 in: "select * from tbl where b = 4 or a = 3", 170 out: "select * from tbl where b = 4 or a = 3", 171 }, { 172 in: "select * from tbl where name=%a", 173 binds: []*querypb.BindVariable{sqltypes.StringBindVariable("xyz")}, 174 out: "select * from tbl where name='xyz'", 175 }, { 176 in: "select * from tbl where c=%a", 177 binds: []*querypb.BindVariable{sqltypes.Int64BindVariable(17)}, 178 out: "select * from tbl where c=17", 179 }, { 180 in: "select * from tbl where name=%a and c=%a", 181 binds: []*querypb.BindVariable{sqltypes.StringBindVariable("xyz"), sqltypes.Int64BindVariable(17)}, 182 out: "select * from tbl where name='xyz' and c=17", 183 }, { 184 in: "select * from tbl where name=%a", 185 binds: []*querypb.BindVariable{sqltypes.StringBindVariable("it's")}, 186 out: "select * from tbl where name='it\\'s'", 187 }, { 188 in: "where name=%a", 189 binds: []*querypb.BindVariable{sqltypes.StringBindVariable("xyz")}, 190 out: "where name='xyz'", 191 }, { 192 in: "name=%a", 193 binds: []*querypb.BindVariable{sqltypes.StringBindVariable("xyz")}, 194 out: "name='xyz'", 195 }, 196 } 197 198 for _, tc := range testcases { 199 t.Run(tc.in, func(t *testing.T) { 200 query, err := ParseAndBind(tc.in, tc.binds...) 201 assert.NoError(t, err) 202 assert.Equal(t, tc.out, query) 203 }) 204 } 205 }