vitess.io/vitess@v0.16.2/go/vt/sqlparser/analyzer_test.go (about) 1 /* 2 Copyright 2019 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package sqlparser 18 19 import ( 20 "testing" 21 22 "github.com/stretchr/testify/assert" 23 ) 24 25 func TestPreview(t *testing.T) { 26 testcases := []struct { 27 sql string 28 want StatementType 29 }{ 30 {"select ...", StmtSelect}, 31 {" select ...", StmtSelect}, 32 {"(select ...", StmtSelect}, 33 {"( select ...", StmtSelect}, 34 {"insert ...", StmtInsert}, 35 {"replace ....", StmtReplace}, 36 {" update ...", StmtUpdate}, 37 {"Update", StmtUpdate}, 38 {"UPDATE ...", StmtUpdate}, 39 {"\n\t delete ...", StmtDelete}, 40 {"", StmtUnknown}, 41 {" ", StmtUnknown}, 42 {"begin", StmtBegin}, 43 {" begin", StmtBegin}, 44 {" begin ", StmtBegin}, 45 {"\n\t begin ", StmtBegin}, 46 {"... begin ", StmtUnknown}, 47 {"begin ...", StmtUnknown}, 48 {"begin /* ... */", StmtBegin}, 49 {"begin /* ... *//*test*/", StmtBegin}, 50 {"begin;", StmtBegin}, 51 {"begin ;", StmtBegin}, 52 {"begin; /*...*/", StmtBegin}, 53 {"start transaction", StmtBegin}, 54 {"commit", StmtCommit}, 55 {"commit /*...*/", StmtCommit}, 56 {"rollback", StmtRollback}, 57 {"rollback /*...*/", StmtRollback}, 58 {"create", StmtDDL}, 59 {"alter", StmtDDL}, 60 {"rename", StmtDDL}, 61 {"drop", StmtDDL}, 62 {"set", StmtSet}, 63 {"show", StmtShow}, 64 {"use", StmtUse}, 65 {"analyze", StmtOther}, 66 {"describe", StmtExplain}, 67 {"desc", StmtExplain}, 68 {"explain", StmtExplain}, 69 {"repair", StmtOther}, 70 {"optimize", StmtOther}, 71 {"grant", StmtPriv}, 72 {"revoke", StmtPriv}, 73 {"truncate", StmtDDL}, 74 {"flush", StmtFlush}, 75 {"unknown", StmtUnknown}, 76 77 {"/* leading comment */ select ...", StmtSelect}, 78 {"/* leading comment */ (select ...", StmtSelect}, 79 {"/* leading comment */ /* leading comment 2 */ select ...", StmtSelect}, 80 {"/*! MySQL-specific comment */", StmtComment}, 81 {"/*!50708 MySQL-version comment */", StmtComment}, 82 {"-- leading single line comment \n select ...", StmtSelect}, 83 {"-- leading single line comment \n -- leading single line comment 2\n select ...", StmtSelect}, 84 85 {"/* leading comment no end select ...", StmtUnknown}, 86 {"-- leading single line comment no end select ...", StmtUnknown}, 87 {"/*!40000 ALTER TABLE `t1` DISABLE KEYS */", StmtComment}, 88 } 89 for _, tcase := range testcases { 90 if got := Preview(tcase.sql); got != tcase.want { 91 t.Errorf("Preview(%s): %v, want %v", tcase.sql, got, tcase.want) 92 } 93 } 94 } 95 96 func TestIsDML(t *testing.T) { 97 testcases := []struct { 98 sql string 99 want bool 100 }{ 101 {" update ...", true}, 102 {"Update", true}, 103 {"UPDATE ...", true}, 104 {"\n\t delete ...", true}, 105 {"insert ...", true}, 106 {"replace ...", true}, 107 {"select ...", false}, 108 {" select ...", false}, 109 {"", false}, 110 {" ", false}, 111 } 112 for _, tcase := range testcases { 113 if got := IsDML(tcase.sql); got != tcase.want { 114 t.Errorf("IsDML(%s): %v, want %v", tcase.sql, got, tcase.want) 115 } 116 } 117 } 118 119 func TestSplitAndExpression(t *testing.T) { 120 testcases := []struct { 121 sql string 122 out []string 123 }{{ 124 sql: "select * from t", 125 out: nil, 126 }, { 127 sql: "select * from t where a = 1", 128 out: []string{"a = 1"}, 129 }, { 130 sql: "select * from t where a = 1 and b = 1", 131 out: []string{"a = 1", "b = 1"}, 132 }, { 133 sql: "select * from t where a = 1 and (b = 1 and c = 1)", 134 out: []string{"a = 1", "b = 1", "c = 1"}, 135 }, { 136 sql: "select * from t where a = 1 and (b = 1 or c = 1)", 137 out: []string{"a = 1", "b = 1 or c = 1"}, 138 }, { 139 sql: "select * from t where a = 1 and b = 1 or c = 1", 140 out: []string{"a = 1 and b = 1 or c = 1"}, 141 }, { 142 sql: "select * from t where a = 1 and b = 1 + (c = 1)", 143 out: []string{"a = 1", "b = 1 + (c = 1)"}, 144 }, { 145 sql: "select * from t where (a = 1 and ((b = 1 and c = 1)))", 146 out: []string{"a = 1", "b = 1", "c = 1"}, 147 }} 148 for _, tcase := range testcases { 149 stmt, err := Parse(tcase.sql) 150 assert.NoError(t, err) 151 var expr Expr 152 if where := stmt.(*Select).Where; where != nil { 153 expr = where.Expr 154 } 155 splits := SplitAndExpression(nil, expr) 156 var got []string 157 for _, split := range splits { 158 got = append(got, String(split)) 159 } 160 assert.Equal(t, tcase.out, got) 161 } 162 } 163 164 func TestAndExpressions(t *testing.T) { 165 greaterThanExpr := &ComparisonExpr{ 166 Operator: GreaterThanOp, 167 Left: &ColName{ 168 Name: NewIdentifierCI("val"), 169 Qualifier: TableName{ 170 Name: NewIdentifierCS("a"), 171 }, 172 }, 173 Right: &ColName{ 174 Name: NewIdentifierCI("val"), 175 Qualifier: TableName{ 176 Name: NewIdentifierCS("b"), 177 }, 178 }, 179 } 180 equalExpr := &ComparisonExpr{ 181 Operator: EqualOp, 182 Left: &ColName{ 183 Name: NewIdentifierCI("id"), 184 Qualifier: TableName{ 185 Name: NewIdentifierCS("a"), 186 }, 187 }, 188 Right: &ColName{ 189 Name: NewIdentifierCI("id"), 190 Qualifier: TableName{ 191 Name: NewIdentifierCS("b"), 192 }, 193 }, 194 } 195 testcases := []struct { 196 name string 197 expressions Exprs 198 expectedOutput Expr 199 }{ 200 { 201 name: "empty input", 202 expressions: nil, 203 expectedOutput: nil, 204 }, { 205 name: "two equal inputs", 206 expressions: Exprs{ 207 greaterThanExpr, 208 equalExpr, 209 equalExpr, 210 }, 211 expectedOutput: &AndExpr{ 212 Left: greaterThanExpr, 213 Right: equalExpr, 214 }, 215 }, 216 { 217 name: "two equal inputs", 218 expressions: Exprs{ 219 equalExpr, 220 equalExpr, 221 }, 222 expectedOutput: equalExpr, 223 }, 224 } 225 226 for _, testcase := range testcases { 227 t.Run(testcase.name, func(t *testing.T) { 228 output := AndExpressions(testcase.expressions...) 229 assert.Equal(t, String(testcase.expectedOutput), String(output)) 230 }) 231 } 232 } 233 234 func TestTableFromStatement(t *testing.T) { 235 testcases := []struct { 236 in, out string 237 }{{ 238 in: "select * from t", 239 out: "t", 240 }, { 241 in: "select * from t.t", 242 out: "t.t", 243 }, { 244 in: "select * from t1, t2", 245 out: "table expression is complex", 246 }, { 247 in: "select * from (t)", 248 out: "table expression is complex", 249 }, { 250 in: "select * from t1 join t2", 251 out: "table expression is complex", 252 }, { 253 in: "select * from (select * from t) as tt", 254 out: "table expression is complex", 255 }, { 256 in: "update t set a=1", 257 out: "unrecognized statement: update t set a=1", 258 }, { 259 in: "bad query", 260 out: "syntax error at position 4 near 'bad'", 261 }} 262 263 for _, tc := range testcases { 264 name, err := TableFromStatement(tc.in) 265 var got string 266 if err != nil { 267 got = err.Error() 268 } else { 269 got = String(name) 270 } 271 if got != tc.out { 272 t.Errorf("TableFromStatement('%s'): %s, want %s", tc.in, got, tc.out) 273 } 274 } 275 } 276 277 func TestGetTableName(t *testing.T) { 278 testcases := []struct { 279 in, out string 280 }{{ 281 in: "select * from t", 282 out: "t", 283 }, { 284 in: "select * from t.t", 285 out: "", 286 }, { 287 in: "select * from (select * from t) as tt", 288 out: "", 289 }} 290 291 for _, tc := range testcases { 292 tree, err := Parse(tc.in) 293 if err != nil { 294 t.Error(err) 295 continue 296 } 297 out := GetTableName(tree.(*Select).From[0].(*AliasedTableExpr).Expr) 298 if out.String() != tc.out { 299 t.Errorf("GetTableName('%s'): %s, want %s", tc.in, out, tc.out) 300 } 301 } 302 } 303 304 func TestIsColName(t *testing.T) { 305 testcases := []struct { 306 in Expr 307 out bool 308 }{{ 309 in: &ColName{}, 310 out: true, 311 }, { 312 in: NewHexLiteral(""), 313 }} 314 for _, tc := range testcases { 315 out := IsColName(tc.in) 316 if out != tc.out { 317 t.Errorf("IsColName(%T): %v, want %v", tc.in, out, tc.out) 318 } 319 } 320 } 321 322 func TestIsNull(t *testing.T) { 323 testcases := []struct { 324 in Expr 325 out bool 326 }{{ 327 in: &NullVal{}, 328 out: true, 329 }, { 330 in: NewStrLiteral(""), 331 }} 332 for _, tc := range testcases { 333 out := IsNull(tc.in) 334 if out != tc.out { 335 t.Errorf("IsNull(%T): %v, want %v", tc.in, out, tc.out) 336 } 337 } 338 }