vitess.io/vitess@v0.16.2/go/vt/sqlparser/comments_test.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package sqlparser
    18  
    19  import (
    20  	"fmt"
    21  	"reflect"
    22  	"testing"
    23  
    24  	"github.com/stretchr/testify/require"
    25  
    26  	"github.com/stretchr/testify/assert"
    27  
    28  	querypb "vitess.io/vitess/go/vt/proto/query"
    29  )
    30  
    31  func TestSplitComments(t *testing.T) {
    32  	var testCases = []struct {
    33  		input, outSQL, outLeadingComments, outTrailingComments string
    34  	}{{
    35  		input:               "/",
    36  		outSQL:              "/",
    37  		outLeadingComments:  "",
    38  		outTrailingComments: "",
    39  	}, {
    40  		input:               "*/",
    41  		outSQL:              "*/",
    42  		outLeadingComments:  "",
    43  		outTrailingComments: "",
    44  	}, {
    45  		input:               "/*/",
    46  		outSQL:              "/*/",
    47  		outLeadingComments:  "",
    48  		outTrailingComments: "",
    49  	}, {
    50  		input:               "a*/",
    51  		outSQL:              "a*/",
    52  		outLeadingComments:  "",
    53  		outTrailingComments: "",
    54  	}, {
    55  		input:               "*a*/",
    56  		outSQL:              "*a*/",
    57  		outLeadingComments:  "",
    58  		outTrailingComments: "",
    59  	}, {
    60  		input:               "**a*/",
    61  		outSQL:              "**a*/",
    62  		outLeadingComments:  "",
    63  		outTrailingComments: "",
    64  	}, {
    65  		input:               "/*b**a*/",
    66  		outSQL:              "",
    67  		outLeadingComments:  "",
    68  		outTrailingComments: "/*b**a*/",
    69  	}, {
    70  		input:               "/*a*/",
    71  		outSQL:              "",
    72  		outLeadingComments:  "",
    73  		outTrailingComments: "/*a*/",
    74  	}, {
    75  		input:               "/**/",
    76  		outSQL:              "",
    77  		outLeadingComments:  "",
    78  		outTrailingComments: "/**/",
    79  	}, {
    80  		input:               "/*b*/ /*a*/",
    81  		outSQL:              "",
    82  		outLeadingComments:  "",
    83  		outTrailingComments: "/*b*/ /*a*/",
    84  	}, {
    85  		input:               "/* before */ foo /* bar */",
    86  		outSQL:              "foo",
    87  		outLeadingComments:  "/* before */ ",
    88  		outTrailingComments: " /* bar */",
    89  	}, {
    90  		input:               "/* before1 */ /* before2 */ foo /* after1 */ /* after2 */",
    91  		outSQL:              "foo",
    92  		outLeadingComments:  "/* before1 */ /* before2 */ ",
    93  		outTrailingComments: " /* after1 */ /* after2 */",
    94  	}, {
    95  		input:               "/** before */ foo /** bar */",
    96  		outSQL:              "foo",
    97  		outLeadingComments:  "/** before */ ",
    98  		outTrailingComments: " /** bar */",
    99  	}, {
   100  		input:               "/*** before */ foo /*** bar */",
   101  		outSQL:              "foo",
   102  		outLeadingComments:  "/*** before */ ",
   103  		outTrailingComments: " /*** bar */",
   104  	}, {
   105  		input:               "/** before **/ foo /** bar **/",
   106  		outSQL:              "foo",
   107  		outLeadingComments:  "/** before **/ ",
   108  		outTrailingComments: " /** bar **/",
   109  	}, {
   110  		input:               "/*** before ***/ foo /*** bar ***/",
   111  		outSQL:              "foo",
   112  		outLeadingComments:  "/*** before ***/ ",
   113  		outTrailingComments: " /*** bar ***/",
   114  	}, {
   115  		input:               " /*** before ***/ foo /*** bar ***/ ",
   116  		outSQL:              "foo",
   117  		outLeadingComments:  "/*** before ***/ ",
   118  		outTrailingComments: " /*** bar ***/",
   119  	}, {
   120  		input:               "*** bar ***/",
   121  		outSQL:              "*** bar ***/",
   122  		outLeadingComments:  "",
   123  		outTrailingComments: "",
   124  	}, {
   125  		input:               " foo ",
   126  		outSQL:              "foo",
   127  		outLeadingComments:  "",
   128  		outTrailingComments: "",
   129  	}, {
   130  		input:               "select 1 from t where col = '*//*'",
   131  		outSQL:              "select 1 from t where col = '*//*'",
   132  		outLeadingComments:  "",
   133  		outTrailingComments: "",
   134  	}, {
   135  		input:               "/*! select 1 */",
   136  		outSQL:              "/*! select 1 */",
   137  		outLeadingComments:  "",
   138  		outTrailingComments: "",
   139  	}}
   140  	for _, testCase := range testCases {
   141  		t.Run(testCase.input, func(t *testing.T) {
   142  			gotSQL, gotComments := SplitMarginComments(testCase.input)
   143  			gotLeadingComments, gotTrailingComments := gotComments.Leading, gotComments.Trailing
   144  
   145  			if gotSQL != testCase.outSQL {
   146  				t.Errorf("test input: '%s', got SQL\n%+v, want\n%+v", testCase.input, gotSQL, testCase.outSQL)
   147  			}
   148  			if gotLeadingComments != testCase.outLeadingComments {
   149  				t.Errorf("test input: '%s', got LeadingComments\n%+v, want\n%+v", testCase.input, gotLeadingComments, testCase.outLeadingComments)
   150  			}
   151  			if gotTrailingComments != testCase.outTrailingComments {
   152  				t.Errorf("test input: '%s', got TrailingComments\n%+v, want\n%+v", testCase.input, gotTrailingComments, testCase.outTrailingComments)
   153  			}
   154  		})
   155  	}
   156  }
   157  
   158  func TestStripLeadingComments(t *testing.T) {
   159  	var testCases = []struct {
   160  		input, outSQL string
   161  	}{{
   162  		input:  "/",
   163  		outSQL: "/",
   164  	}, {
   165  		input:  "*/",
   166  		outSQL: "*/",
   167  	}, {
   168  		input:  "/*/",
   169  		outSQL: "/*/",
   170  	}, {
   171  		input:  "/*a",
   172  		outSQL: "/*a",
   173  	}, {
   174  		input:  "/*a*",
   175  		outSQL: "/*a*",
   176  	}, {
   177  		input:  "/*a**",
   178  		outSQL: "/*a**",
   179  	}, {
   180  		input:  "/*b**a*/",
   181  		outSQL: "",
   182  	}, {
   183  		input:  "/*a*/",
   184  		outSQL: "",
   185  	}, {
   186  		input:  "/**/",
   187  		outSQL: "",
   188  	}, {
   189  		input:  "/*!*/",
   190  		outSQL: "/*!*/",
   191  	}, {
   192  		input:  "/*!a*/",
   193  		outSQL: "/*!a*/",
   194  	}, {
   195  		input:  "/*b*/ /*a*/",
   196  		outSQL: "",
   197  	}, {
   198  		input: `/*b*/ --foo
   199  bar`,
   200  		outSQL: "bar",
   201  	}, {
   202  		input:  "foo /* bar */",
   203  		outSQL: "foo /* bar */",
   204  	}, {
   205  		input:  "/* foo */ bar",
   206  		outSQL: "bar",
   207  	}, {
   208  		input:  "-- /* foo */ bar",
   209  		outSQL: "",
   210  	}, {
   211  		input:  "foo -- bar */",
   212  		outSQL: "foo -- bar */",
   213  	}, {
   214  		input: `/*
   215  foo */ bar`,
   216  		outSQL: "bar",
   217  	}, {
   218  		input: `-- foo bar
   219  a`,
   220  		outSQL: "a",
   221  	}, {
   222  		input:  `-- foo bar`,
   223  		outSQL: "",
   224  	}}
   225  	for _, testCase := range testCases {
   226  		gotSQL := StripLeadingComments(testCase.input)
   227  
   228  		if gotSQL != testCase.outSQL {
   229  			t.Errorf("test input: '%s', got SQL\n%+v, want\n%+v", testCase.input, gotSQL, testCase.outSQL)
   230  		}
   231  	}
   232  }
   233  
   234  func TestExtractMysqlComment(t *testing.T) {
   235  	var testCases = []struct {
   236  		input, outSQL, outVersion string
   237  	}{{
   238  		input:      "/*!50708SET max_execution_time=5000 */",
   239  		outSQL:     "SET max_execution_time=5000",
   240  		outVersion: "50708",
   241  	}, {
   242  		input:      "/*!50708 SET max_execution_time=5000*/",
   243  		outSQL:     "SET max_execution_time=5000",
   244  		outVersion: "50708",
   245  	}, {
   246  		input:      "/*!50708* from*/",
   247  		outSQL:     "* from",
   248  		outVersion: "50708",
   249  	}, {
   250  		input:      "/*! SET max_execution_time=5000*/",
   251  		outSQL:     "SET max_execution_time=5000",
   252  		outVersion: "",
   253  	}}
   254  	for _, testCase := range testCases {
   255  		gotVersion, gotSQL := ExtractMysqlComment(testCase.input)
   256  
   257  		if gotVersion != testCase.outVersion {
   258  			t.Errorf("test input: '%s', got version\n%+v, want\n%+v", testCase.input, gotVersion, testCase.outVersion)
   259  		}
   260  		if gotSQL != testCase.outSQL {
   261  			t.Errorf("test input: '%s', got SQL\n%+v, want\n%+v", testCase.input, gotSQL, testCase.outSQL)
   262  		}
   263  	}
   264  }
   265  
   266  func TestExtractCommentDirectives(t *testing.T) {
   267  	var testCases = []struct {
   268  		input string
   269  		vals  map[string]string
   270  	}{{
   271  		input: "",
   272  		vals:  nil,
   273  	}, {
   274  		input: "/* not a vt comment */",
   275  		vals:  map[string]string{},
   276  	}, {
   277  		input: "/*vt+ */",
   278  		vals:  map[string]string{},
   279  	}, {
   280  		input: "/*vt+ SINGLE_OPTION */",
   281  		vals: map[string]string{
   282  			"single_option": "true",
   283  		},
   284  	}, {
   285  		input: "/*vt+ ONE_OPT TWO_OPT */",
   286  		vals: map[string]string{
   287  			"one_opt": "true",
   288  			"two_opt": "true",
   289  		},
   290  	}, {
   291  		input: "/*vt+ ONE_OPT */ /* other comment */ /*vt+ TWO_OPT */",
   292  		vals: map[string]string{
   293  			"one_opt": "true",
   294  			"two_opt": "true",
   295  		},
   296  	}, {
   297  		input: "/*vt+ ONE_OPT=abc TWO_OPT=def */",
   298  		vals: map[string]string{
   299  			"one_opt": "abc",
   300  			"two_opt": "def",
   301  		},
   302  	}, {
   303  		input: "/*vt+ ONE_OPT=true TWO_OPT=false */",
   304  		vals: map[string]string{
   305  			"one_opt": "true",
   306  			"two_opt": "false",
   307  		},
   308  	}, {
   309  		input: "/*vt+ ONE_OPT=true TWO_OPT=\"false\" */",
   310  		vals: map[string]string{
   311  			"one_opt": "true",
   312  			"two_opt": "\"false\"",
   313  		},
   314  	}, {
   315  		input: "/*vt+ RANGE_OPT=[a:b] ANOTHER ANOTHER_WITH_VALEQ=val= AND_ONE_WITH_EQ== */",
   316  		vals: map[string]string{
   317  			"range_opt":          "[a:b]",
   318  			"another":            "true",
   319  			"another_with_valeq": "val=",
   320  			"and_one_with_eq":    "=",
   321  		},
   322  	}}
   323  
   324  	for _, testCase := range testCases {
   325  		t.Run(testCase.input, func(t *testing.T) {
   326  			sqls := []string{
   327  				"select " + testCase.input + " 1 from dual",
   328  				"update " + testCase.input + " t set i=i+1",
   329  				"delete " + testCase.input + " from t where id>1",
   330  				"drop " + testCase.input + " table t",
   331  				"create " + testCase.input + " table if not exists t (id int primary key)",
   332  				"alter " + testCase.input + " table t add column c int not null",
   333  				"create " + testCase.input + " view v as select * from t",
   334  				"create " + testCase.input + " or replace view v as select * from t",
   335  				"alter " + testCase.input + " view v as select * from t",
   336  				"drop " + testCase.input + " view v",
   337  			}
   338  			for _, sql := range sqls {
   339  				t.Run(sql, func(t *testing.T) {
   340  					var comments *ParsedComments
   341  					stmt, _ := Parse(sql)
   342  					switch s := stmt.(type) {
   343  					case *Select:
   344  						comments = s.Comments
   345  					case *Update:
   346  						comments = s.Comments
   347  					case *Delete:
   348  						comments = s.Comments
   349  					case *DropTable:
   350  						comments = s.Comments
   351  					case *AlterTable:
   352  						comments = s.Comments
   353  					case *CreateTable:
   354  						comments = s.Comments
   355  					case *CreateView:
   356  						comments = s.Comments
   357  					case *AlterView:
   358  						comments = s.Comments
   359  					case *DropView:
   360  						comments = s.Comments
   361  					default:
   362  						t.Errorf("Unexpected statement type %+v", s)
   363  					}
   364  
   365  					vals := comments.Directives()
   366  					if vals == nil {
   367  						require.Nil(t, vals)
   368  						return
   369  					}
   370  					if !reflect.DeepEqual(vals.m, testCase.vals) {
   371  						t.Errorf("test input: '%v', got vals %T:\n%+v, want %T\n%+v", testCase.input, vals, vals, testCase.vals, testCase.vals)
   372  					}
   373  				})
   374  			}
   375  		})
   376  	}
   377  
   378  	d := &CommentDirectives{m: map[string]string{
   379  		"one_opt": "true",
   380  		"two_opt": "false",
   381  		"three":   "1",
   382  		"four":    "2",
   383  		"five":    "0",
   384  		"six":     "true",
   385  	}}
   386  
   387  	assert.True(t, d.IsSet("ONE_OPT"), "d.IsSet(ONE_OPT)")
   388  	assert.False(t, d.IsSet("TWO_OPT"), "d.IsSet(TWO_OPT)")
   389  	assert.True(t, d.IsSet("three"), "d.IsSet(three)")
   390  	assert.False(t, d.IsSet("four"), "d.IsSet(four)")
   391  	assert.False(t, d.IsSet("five"), "d.IsSet(five)")
   392  	assert.True(t, d.IsSet("six"), "d.IsSet(six)")
   393  }
   394  
   395  func TestSkipQueryPlanCacheDirective(t *testing.T) {
   396  	stmt, _ := Parse("insert /*vt+ SKIP_QUERY_PLAN_CACHE=1 */ into user(id) values (1), (2)")
   397  	if !SkipQueryPlanCacheDirective(stmt) {
   398  		t.Errorf("d.SkipQueryPlanCacheDirective(stmt) should be true")
   399  	}
   400  
   401  	stmt, _ = Parse("insert into user(id) values (1), (2)")
   402  	if SkipQueryPlanCacheDirective(stmt) {
   403  		t.Errorf("d.SkipQueryPlanCacheDirective(stmt) should be false")
   404  	}
   405  
   406  	stmt, _ = Parse("update /*vt+ SKIP_QUERY_PLAN_CACHE=1 */ users set name=1")
   407  	if !SkipQueryPlanCacheDirective(stmt) {
   408  		t.Errorf("d.SkipQueryPlanCacheDirective(stmt) should be true")
   409  	}
   410  
   411  	stmt, _ = Parse("select /*vt+ SKIP_QUERY_PLAN_CACHE=1 */ * from users")
   412  	if !SkipQueryPlanCacheDirective(stmt) {
   413  		t.Errorf("d.SkipQueryPlanCacheDirective(stmt) should be true")
   414  	}
   415  
   416  	stmt, _ = Parse("delete /*vt+ SKIP_QUERY_PLAN_CACHE=1 */ from users")
   417  	if !SkipQueryPlanCacheDirective(stmt) {
   418  		t.Errorf("d.SkipQueryPlanCacheDirective(stmt) should be true")
   419  	}
   420  }
   421  
   422  func TestIgnoreMaxPayloadSizeDirective(t *testing.T) {
   423  	testCases := []struct {
   424  		query    string
   425  		expected bool
   426  	}{
   427  		{"insert /*vt+ IGNORE_MAX_PAYLOAD_SIZE=1 */ into user(id) values (1), (2)", true},
   428  		{"insert into user(id) values (1), (2)", false},
   429  		{"update /*vt+ IGNORE_MAX_PAYLOAD_SIZE=1 */ users set name=1", true},
   430  		{"update users set name=1", false},
   431  		{"select /*vt+ IGNORE_MAX_PAYLOAD_SIZE=1 */ * from users", true},
   432  		{"select * from users", false},
   433  		{"delete /*vt+ IGNORE_MAX_PAYLOAD_SIZE=1 */ from users", true},
   434  		{"delete from users", false},
   435  		{"show /*vt+ IGNORE_MAX_PAYLOAD_SIZE=1 */ create table users", false},
   436  		{"show create table users", false},
   437  	}
   438  
   439  	for _, test := range testCases {
   440  		t.Run(test.query, func(t *testing.T) {
   441  			stmt, _ := Parse(test.query)
   442  			got := IgnoreMaxPayloadSizeDirective(stmt)
   443  			assert.Equalf(t, test.expected, got, fmt.Sprintf("IgnoreMaxPayloadSizeDirective(stmt) returned %v but expected %v", got, test.expected))
   444  		})
   445  	}
   446  }
   447  
   448  func TestIgnoreMaxMaxMemoryRowsDirective(t *testing.T) {
   449  	testCases := []struct {
   450  		query    string
   451  		expected bool
   452  	}{
   453  		{"insert /*vt+ IGNORE_MAX_MEMORY_ROWS=1 */ into user(id) values (1), (2)", true},
   454  		{"insert into user(id) values (1), (2)", false},
   455  		{"update /*vt+ IGNORE_MAX_MEMORY_ROWS=1 */ users set name=1", true},
   456  		{"update users set name=1", false},
   457  		{"select /*vt+ IGNORE_MAX_MEMORY_ROWS=1 */ * from users", true},
   458  		{"select * from users", false},
   459  		{"delete /*vt+ IGNORE_MAX_MEMORY_ROWS=1 */ from users", true},
   460  		{"delete from users", false},
   461  		{"show /*vt+ IGNORE_MAX_MEMORY_ROWS=1 */ create table users", false},
   462  		{"show create table users", false},
   463  	}
   464  
   465  	for _, test := range testCases {
   466  		t.Run(test.query, func(t *testing.T) {
   467  			stmt, _ := Parse(test.query)
   468  			got := IgnoreMaxMaxMemoryRowsDirective(stmt)
   469  			assert.Equalf(t, test.expected, got, fmt.Sprintf("IgnoreMaxPayloadSizeDirective(stmt) returned %v but expected %v", got, test.expected))
   470  		})
   471  	}
   472  }
   473  
   474  func TestConsolidator(t *testing.T) {
   475  	testCases := []struct {
   476  		query    string
   477  		expected querypb.ExecuteOptions_Consolidator
   478  	}{
   479  		{"insert /*vt+ CONSOLIDATOR=enabled */ into user(id) values (1), (2)", querypb.ExecuteOptions_CONSOLIDATOR_UNSPECIFIED},
   480  		{"update /*vt+ CONSOLIDATOR=enabled */ users set name=1", querypb.ExecuteOptions_CONSOLIDATOR_UNSPECIFIED},
   481  		{"delete /*vt+ CONSOLIDATOR=enabled */ from users", querypb.ExecuteOptions_CONSOLIDATOR_UNSPECIFIED},
   482  		{"show /*vt+ CONSOLIDATOR=enabled */ create table users", querypb.ExecuteOptions_CONSOLIDATOR_UNSPECIFIED},
   483  		{"select * from users", querypb.ExecuteOptions_CONSOLIDATOR_UNSPECIFIED},
   484  		{"select /*vt+ CONSOLIDATOR=invalid_value */ * from users", querypb.ExecuteOptions_CONSOLIDATOR_UNSPECIFIED},
   485  		{"select /*vt+ IGNORE_MAX_MEMORY_ROWS=1 */ * from users", querypb.ExecuteOptions_CONSOLIDATOR_UNSPECIFIED},
   486  		{"select /*vt+ CONSOLIDATOR=disabled */ * from users", querypb.ExecuteOptions_CONSOLIDATOR_DISABLED},
   487  		{"select /*vt+ CONSOLIDATOR=enabled */ * from users", querypb.ExecuteOptions_CONSOLIDATOR_ENABLED},
   488  		{"select /*vt+ CONSOLIDATOR=enabled_replicas */ * from users", querypb.ExecuteOptions_CONSOLIDATOR_ENABLED_REPLICAS},
   489  	}
   490  
   491  	for _, test := range testCases {
   492  		t.Run(test.query, func(t *testing.T) {
   493  			stmt, _ := Parse(test.query)
   494  			got := Consolidator(stmt)
   495  			assert.Equalf(t, test.expected, got, fmt.Sprintf("Consolidator(stmt) returned %v but expected %v", got, test.expected))
   496  		})
   497  	}
   498  }