github.com/matrixorigin/matrixone@v1.2.0/pkg/sql/parsers/sqlparse_test.go (about)

     1  // Copyright 2021 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package parsers
    16  
    17  import (
    18  	"context"
    19  	"testing"
    20  
    21  	"github.com/stretchr/testify/require"
    22  
    23  	"github.com/matrixorigin/matrixone/pkg/sql/parsers/dialect"
    24  	"github.com/matrixorigin/matrixone/pkg/sql/parsers/dialect/mysql"
    25  	"github.com/matrixorigin/matrixone/pkg/sql/parsers/dialect/postgresql"
    26  	"github.com/matrixorigin/matrixone/pkg/sql/parsers/tree"
    27  )
    28  
    29  var (
    30  	debugSQL = struct {
    31  		input  string
    32  		output string
    33  	}{
    34  		input: "use db1",
    35  	}
    36  )
    37  
    38  func TestMysql(t *testing.T) {
    39  	ctx := context.TODO()
    40  	if debugSQL.output == "" {
    41  		debugSQL.output = debugSQL.input
    42  	}
    43  	ast, err := mysql.ParseOne(ctx, debugSQL.input, 1, 0)
    44  	if err != nil {
    45  		t.Errorf("Parse(%q) err: %v", debugSQL.input, err)
    46  		return
    47  	}
    48  	out := tree.String(ast, dialect.MYSQL)
    49  	if debugSQL.output != out {
    50  		t.Errorf("Parsing failed. \nExpected/Got:\n%s\n%s", debugSQL.output, out)
    51  	}
    52  }
    53  
    54  func TestPostgresql(t *testing.T) {
    55  	ctx := context.TODO()
    56  	if debugSQL.output == "" {
    57  		debugSQL.output = debugSQL.input
    58  	}
    59  	ast, err := postgresql.ParseOne(ctx, debugSQL.input)
    60  	if err != nil {
    61  		t.Errorf("Parse(%q) err: %v", debugSQL.input, err)
    62  		return
    63  	}
    64  	out := tree.String(ast, dialect.POSTGRESQL)
    65  	if debugSQL.output != out {
    66  		t.Errorf("Parsing failed. \nExpected/Got:\n%s\n%s", debugSQL.output, out)
    67  	}
    68  }
    69  
    70  func TestSplitSqlBySemicolon(t *testing.T) {
    71  	ret := SplitSqlBySemicolon("select 1;select 2;select 3;")
    72  	require.Equal(t, 3, len(ret))
    73  	require.Equal(t, "select 1", ret[0])
    74  	require.Equal(t, "select 2", ret[1])
    75  	require.Equal(t, "select 3", ret[2])
    76  
    77  	ret = SplitSqlBySemicolon("select 1;select 2/*;;;*/;select 3;")
    78  	require.Equal(t, 3, len(ret))
    79  	require.Equal(t, "select 1", ret[0])
    80  	require.Equal(t, "select 2/*;;;*/", ret[1])
    81  	require.Equal(t, "select 3", ret[2])
    82  
    83  	ret = SplitSqlBySemicolon("select 1;select \"2;;\";select 3;")
    84  	require.Equal(t, 3, len(ret))
    85  	require.Equal(t, "select 1", ret[0])
    86  	require.Equal(t, "select \"2;;\"", ret[1])
    87  	require.Equal(t, "select 3", ret[2])
    88  
    89  	ret = SplitSqlBySemicolon("select 1;select '2;;';select 3;")
    90  	require.Equal(t, 3, len(ret))
    91  	require.Equal(t, "select 1", ret[0])
    92  	require.Equal(t, "select '2;;'", ret[1])
    93  	require.Equal(t, "select 3", ret[2])
    94  
    95  	ret = SplitSqlBySemicolon("select 1;select '2;;';select 3")
    96  	require.Equal(t, 3, len(ret))
    97  	require.Equal(t, "select 1", ret[0])
    98  	require.Equal(t, "select '2;;'", ret[1])
    99  	require.Equal(t, "select 3", ret[2])
   100  
   101  	ret = SplitSqlBySemicolon("select 1")
   102  	require.Equal(t, 1, len(ret))
   103  	require.Equal(t, "select 1", ret[0])
   104  
   105  	ret = SplitSqlBySemicolon(";;;")
   106  	require.Equal(t, 3, len(ret))
   107  	require.Equal(t, "", ret[0])
   108  	require.Equal(t, "", ret[1])
   109  	require.Equal(t, "", ret[2])
   110  
   111  	ret = SplitSqlBySemicolon(";;;  ")
   112  	require.Equal(t, 3, len(ret))
   113  	require.Equal(t, "", ret[0])
   114  	require.Equal(t, "", ret[1])
   115  	require.Equal(t, "", ret[2])
   116  
   117  	ret = SplitSqlBySemicolon(";")
   118  	require.Equal(t, 1, len(ret))
   119  	require.Equal(t, "", ret[0])
   120  
   121  	ret = SplitSqlBySemicolon("")
   122  	require.Equal(t, 1, len(ret))
   123  	require.Equal(t, "", ret[0])
   124  
   125  	ret = SplitSqlBySemicolon("   ;   ")
   126  	require.Equal(t, 1, len(ret))
   127  	require.Equal(t, "", ret[0])
   128  
   129  	ret = SplitSqlBySemicolon("   ")
   130  	require.Equal(t, 1, len(ret))
   131  	require.Equal(t, "", ret[0])
   132  
   133  	ret = SplitSqlBySemicolon("  ; /* abc */ ")
   134  	require.Equal(t, 2, len(ret))
   135  	require.Equal(t, "", ret[0])
   136  	require.Equal(t, "/* abc */", ret[1])
   137  
   138  	ret = SplitSqlBySemicolon(" /* cde */  ; /* abc */ ")
   139  	require.Equal(t, 2, len(ret))
   140  	require.Equal(t, "/* cde */", ret[0])
   141  	require.Equal(t, "/* abc */", ret[1])
   142  
   143  	ret = SplitSqlBySemicolon("   ;    ;  ")
   144  	require.Equal(t, 2, len(ret))
   145  	require.Equal(t, "", ret[0])
   146  	require.Equal(t, "", ret[1])
   147  
   148  	ret = SplitSqlBySemicolon("   ;    ;")
   149  	require.Equal(t, 2, len(ret))
   150  	require.Equal(t, "", ret[0])
   151  	require.Equal(t, "", ret[1])
   152  
   153  	ret = SplitSqlBySemicolon("   ;   ")
   154  	require.Equal(t, 1, len(ret))
   155  	require.Equal(t, "", ret[0])
   156  }
   157  
   158  func TestHandleSqlForRecord(t *testing.T) {
   159  	// Test remove /* cloud_user */ prefix
   160  	var ret []string
   161  	ret = HandleSqlForRecord(" ;   ;  ")
   162  	require.Equal(t, 2, len(ret))
   163  	require.Equal(t, "", ret[0])
   164  	require.Equal(t, "", ret[1])
   165  
   166  	ret = HandleSqlForRecord(" ; /* abc */  ")
   167  	require.Equal(t, 2, len(ret))
   168  	require.Equal(t, "", ret[0])
   169  	require.Equal(t, "/* abc */", ret[1])
   170  
   171  	ret = HandleSqlForRecord(" /* cde */  ; /* abc */ ")
   172  	require.Equal(t, 2, len(ret))
   173  	require.Equal(t, "/* cde */", ret[0])
   174  	require.Equal(t, "/* abc */", ret[1])
   175  
   176  	ret = HandleSqlForRecord(" /* cde */  ; /* abc */ ; " + stripCloudNonUser + " ; " + stripCloudUser)
   177  	require.Equal(t, 4, len(ret))
   178  	require.Equal(t, "/* cde */", ret[0])
   179  	require.Equal(t, "/* abc */", ret[1])
   180  	require.Equal(t, "", ret[2])
   181  	require.Equal(t, "", ret[3])
   182  
   183  	ret = HandleSqlForRecord("  /* cloud_user */ select 1;   ")
   184  	require.Equal(t, 1, len(ret))
   185  	require.Equal(t, "select 1", ret[0])
   186  
   187  	ret = HandleSqlForRecord("  /* cloud_user */ select 1;  ")
   188  	require.Equal(t, 1, len(ret))
   189  	require.Equal(t, "select 1", ret[0])
   190  
   191  	ret = HandleSqlForRecord("  /* cloud_user */select * from t;/* cloud_user */select * from t;/* cloud_user */select * from t;")
   192  	require.Equal(t, 3, len(ret))
   193  	require.Equal(t, "select * from t", ret[0])
   194  	require.Equal(t, "select * from t", ret[1])
   195  	require.Equal(t, "select * from t", ret[2])
   196  
   197  	ret = HandleSqlForRecord("  /* cloud_user */  select * from t ;  /* cloud_user */  select * from t ; /* cloud_user */ select * from t ; ")
   198  	require.Equal(t, 3, len(ret))
   199  	require.Equal(t, "select * from t", ret[0])
   200  	require.Equal(t, "select * from t", ret[1])
   201  	require.Equal(t, "select * from t", ret[2])
   202  
   203  	ret = HandleSqlForRecord("  /* cloud_user */  select * from t ;  /* cloud_user */  select * from t ; /* cloud_user */ select * from t ; /* abc */ ")
   204  	require.Equal(t, 4, len(ret))
   205  	require.Equal(t, "select * from t", ret[0])
   206  	require.Equal(t, "select * from t", ret[1])
   207  	require.Equal(t, "select * from t", ret[2])
   208  	require.Equal(t, "/* abc */", ret[3])
   209  
   210  	ret = HandleSqlForRecord("  /* cloud_user */  ")
   211  	require.Equal(t, 1, len(ret))
   212  	require.Equal(t, "", ret[0])
   213  
   214  	ret = HandleSqlForRecord("  /* cloud_user */   ")
   215  	require.Equal(t, 1, len(ret))
   216  	require.Equal(t, "", ret[0])
   217  
   218  	ret = HandleSqlForRecord("   " + stripCloudNonUser + "  select 1;   ")
   219  	require.Equal(t, 1, len(ret))
   220  	require.Equal(t, "select 1", ret[0])
   221  
   222  	ret = HandleSqlForRecord("  " + stripCloudNonUser + "  select * from t  ;  " + stripCloudNonUser + "   select * from t  ;   " + stripCloudNonUser + "   select * from t  ;   ")
   223  	require.Equal(t, 3, len(ret))
   224  	require.Equal(t, "select * from t", ret[0])
   225  	require.Equal(t, "select * from t", ret[1])
   226  	require.Equal(t, "select * from t", ret[2])
   227  
   228  	ret = HandleSqlForRecord("  " + stripCloudNonUser + "  select * from t  ;  " + stripCloudNonUser + "   select * from t  ;   " + stripCloudNonUser + "   select * from t  ; /* abc */  ")
   229  	require.Equal(t, 4, len(ret))
   230  	require.Equal(t, "select * from t", ret[0])
   231  	require.Equal(t, "select * from t", ret[1])
   232  	require.Equal(t, "select * from t", ret[2])
   233  	require.Equal(t, "/* abc */", ret[3])
   234  
   235  	ret = HandleSqlForRecord("   " + stripCloudNonUser + "  ")
   236  	require.Equal(t, 1, len(ret))
   237  	require.Equal(t, "", ret[0])
   238  
   239  	ret = HandleSqlForRecord("   " + stripCloudUser + "  ")
   240  	require.Equal(t, 1, len(ret))
   241  	require.Equal(t, "", ret[0])
   242  
   243  	ret = HandleSqlForRecord("")
   244  	require.Equal(t, 1, len(ret))
   245  	require.Equal(t, "", ret[0])
   246  
   247  	// Test hide secret key
   248  
   249  	ret = HandleSqlForRecord("create user u identified by '123456';")
   250  	require.Equal(t, 1, len(ret))
   251  	require.Equal(t, "create user u identified by '******'", ret[0])
   252  
   253  	ret = HandleSqlForRecord("create user u identified with '12345';")
   254  	require.Equal(t, 1, len(ret))
   255  	require.Equal(t, "create user u identified with '******'", ret[0])
   256  
   257  	ret = HandleSqlForRecord("create user u identified by random password;")
   258  	require.Equal(t, 1, len(ret))
   259  	require.Equal(t, "create user u identified by random password", ret[0])
   260  
   261  	ret = HandleSqlForRecord("create user if not exists abc1 identified by '123', abc2 identified by '234', abc3 identified with '111', abc3 identified by random password;")
   262  	require.Equal(t, 1, len(ret))
   263  	require.Equal(t, "create user if not exists abc1 identified by '******', abc2 identified by '******', abc3 identified with '******', abc3 identified by random password", ret[0])
   264  
   265  	ret = HandleSqlForRecord("create external table t (a int) URL s3option{'endpoint'='s3.us-west-2.amazonaws.com', 'access_key_id'='123', 'secret_access_key'='123', 'bucket'='test', 'filepath'='*.txt', 'region'='us-west-2'};")
   266  	require.Equal(t, 1, len(ret))
   267  	require.Equal(t, "create external table t (a int) URL s3option{'endpoint'='s3.us-west-2.amazonaws.com', 'access_key_id'='******', 'secret_access_key'='******', 'bucket'='test', 'filepath'='*.txt', 'region'='us-west-2'}", ret[0])
   268  
   269  	ret = HandleSqlForRecord("/* cloud_user *//* save_result */select count(*) from a;")
   270  	require.Equal(t, 1, len(ret))
   271  	require.Equal(t, "select count(*) from a", ret[0])
   272  
   273  	ret = HandleSqlForRecord("/* cloud_user    *//* save_result    */select count(*) from a;")
   274  	require.Equal(t, 1, len(ret))
   275  	require.Equal(t, "select count(*) from a", ret[0])
   276  
   277  	ret = HandleSqlForRecord("/* cloud_user    *//* save_result    */ /*abc */select count(*) from a;")
   278  	require.Equal(t, 1, len(ret))
   279  	require.Equal(t, "/*abc */select count(*) from a", ret[0])
   280  
   281  	ret = HandleSqlForRecord("/* cloud_user    *//* save_result    */ /*abc */select count(*) from a // def;")
   282  	require.Equal(t, 1, len(ret))
   283  	require.Equal(t, "/*abc */select count(*) from a // def;", ret[0])
   284  }