github.com/matrixorigin/matrixone@v0.7.0/pkg/sql/plan/build_test.go (about)

     1  // Copyright 2021 - 2022 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package plan
    16  
    17  import (
    18  	"bytes"
    19  	"context"
    20  	"encoding/json"
    21  	"github.com/stretchr/testify/assert"
    22  	"os"
    23  	"strings"
    24  	"testing"
    25  
    26  	"github.com/matrixorigin/matrixone/pkg/vm/process"
    27  
    28  	"github.com/matrixorigin/matrixone/pkg/pb/plan"
    29  	"github.com/matrixorigin/matrixone/pkg/sql/parsers/dialect/mysql"
    30  )
    31  
    32  // only use in developing
    33  func TestSingleSQL(t *testing.T) {
    34  	sql := "select * from nation"
    35  
    36  	mock := NewMockOptimizer(false)
    37  	logicPlan, err := runOneStmt(mock, t, sql)
    38  	if err != nil {
    39  		t.Fatalf("%+v", err)
    40  	}
    41  	outPutPlan(logicPlan, true, t)
    42  }
    43  
    44  //Test Query Node Tree
    45  // func TestNodeTree(t *testing.T) {
    46  // 	type queryCheck struct {
    47  // 		steps    []int32                    //steps
    48  // 		nodeType map[int]plan.Node_NodeType //node_type in each node
    49  // 		children map[int][]int32            //children in each node
    50  // 	}
    51  
    52  // 	// map[sql string]checkData
    53  // 	nodeTreeCheckList := map[string]queryCheck{
    54  // 		"SELECT -1": {
    55  // 			steps: []int32{0},
    56  // 			nodeType: map[int]plan.Node_NodeType{
    57  // 				0: plan.Node_VALUE_SCAN,
    58  // 			},
    59  // 			children: nil,
    60  // 		},
    61  // 		"SELECT -1 from dual": {
    62  // 			steps: []int32{0},
    63  // 			nodeType: map[int]plan.Node_NodeType{
    64  // 				0: plan.Node_VALUE_SCAN,
    65  // 			},
    66  // 			children: nil,
    67  // 		},
    68  // 		// one node
    69  // 		"SELECT N_NAME FROM NATION WHERE N_REGIONKEY = 3": {
    70  // 			steps: []int32{0},
    71  // 			nodeType: map[int]plan.Node_NodeType{
    72  // 				0: plan.Node_TABLE_SCAN,
    73  // 			},
    74  // 			children: nil,
    75  // 		},
    76  // 		// two nodes- SCAN + SORT
    77  // 		"SELECT N_NAME FROM NATION WHERE N_REGIONKEY = 3 Order By N_REGIONKEY": {
    78  // 			steps: []int32{1},
    79  // 			nodeType: map[int]plan.Node_NodeType{
    80  // 				0: plan.Node_TABLE_SCAN,
    81  // 				1: plan.Node_SORT,
    82  // 			},
    83  // 			children: map[int][]int32{
    84  // 				1: {0},
    85  // 			},
    86  // 		},
    87  // 		// two nodes- SCAN + AGG(group by)
    88  // 		"SELECT N_NAME FROM NATION WHERE N_REGIONKEY = 3 Group By N_NAME": {
    89  // 			steps: []int32{1},
    90  // 			nodeType: map[int]plan.Node_NodeType{
    91  // 				0: plan.Node_TABLE_SCAN,
    92  // 				1: plan.Node_AGG,
    93  // 			},
    94  // 			children: map[int][]int32{
    95  // 				1: {0},
    96  // 			},
    97  // 		},
    98  // 		"select sum(n_nationkey) from nation": {
    99  // 			steps: []int32{1},
   100  // 			nodeType: map[int]plan.Node_NodeType{
   101  // 				0: plan.Node_TABLE_SCAN,
   102  // 				1: plan.Node_AGG,
   103  // 			},
   104  // 			children: map[int][]int32{
   105  // 				1: {0},
   106  // 			},
   107  // 		},
   108  // 		"select sum(n_nationkey) from nation order by sum(n_nationkey)": {
   109  // 			steps: []int32{2},
   110  // 			nodeType: map[int]plan.Node_NodeType{
   111  // 				0: plan.Node_TABLE_SCAN,
   112  // 				1: plan.Node_AGG,
   113  // 				2: plan.Node_SORT,
   114  // 			},
   115  // 			children: map[int][]int32{
   116  // 				1: {0},
   117  // 				2: {1},
   118  // 			},
   119  // 		},
   120  // 		// two nodes- SCAN + AGG(distinct)
   121  // 		"SELECT distinct N_NAME FROM NATION": {
   122  // 			steps: []int32{1},
   123  // 			nodeType: map[int]plan.Node_NodeType{
   124  // 				0: plan.Node_TABLE_SCAN,
   125  // 				1: plan.Node_AGG,
   126  // 			},
   127  // 			children: map[int][]int32{
   128  // 				1: {0},
   129  // 			},
   130  // 		},
   131  // 		// three nodes- SCAN + AGG(group by) + SORT
   132  // 		"SELECT N_NAME, count(*) as ttl FROM NATION Group By N_NAME Order By ttl": {
   133  // 			steps: []int32{2},
   134  // 			nodeType: map[int]plan.Node_NodeType{
   135  // 				0: plan.Node_TABLE_SCAN,
   136  // 				1: plan.Node_AGG,
   137  // 				2: plan.Node_SORT,
   138  // 			},
   139  // 			children: map[int][]int32{
   140  // 				1: {0},
   141  // 				2: {1},
   142  // 			},
   143  // 		},
   144  // 		// three nodes - SCAN, SCAN, JOIN
   145  // 		"SELECT N_NAME, N_REGIONKEY FROM NATION join REGION on NATION.N_REGIONKEY = REGION.R_REGIONKEY": {
   146  // 			steps: []int32{3},
   147  // 			nodeType: map[int]plan.Node_NodeType{
   148  // 				0: plan.Node_TABLE_SCAN,
   149  // 				1: plan.Node_TABLE_SCAN,
   150  // 				2: plan.Node_JOIN,
   151  // 				3: plan.Node_PROJECT,
   152  // 			},
   153  // 			children: map[int][]int32{
   154  // 				2: {0, 1},
   155  // 			},
   156  // 		},
   157  // 		// three nodes - SCAN, SCAN, JOIN  //use where for join condition
   158  // 		"SELECT N_NAME, N_REGIONKEY FROM NATION, REGION WHERE NATION.N_REGIONKEY = REGION.R_REGIONKEY": {
   159  // 			steps: []int32{3},
   160  // 			nodeType: map[int]plan.Node_NodeType{
   161  // 				0: plan.Node_TABLE_SCAN,
   162  // 				1: plan.Node_TABLE_SCAN,
   163  // 				2: plan.Node_JOIN,
   164  // 				3: plan.Node_PROJECT,
   165  // 			},
   166  // 			children: map[int][]int32{
   167  // 				2: {0, 1},
   168  // 				3: {2},
   169  // 			},
   170  // 		},
   171  // 		// 5 nodes - SCAN, SCAN, JOIN, SCAN, JOIN  //join three table
   172  // 		"SELECT l.L_ORDERKEY FROM CUSTOMER c, ORDERS o, LINEITEM l WHERE c.C_CUSTKEY = o.O_CUSTKEY and l.L_ORDERKEY = o.O_ORDERKEY and o.O_ORDERKEY < 10": {
   173  // 			steps: []int32{6},
   174  // 			nodeType: map[int]plan.Node_NodeType{
   175  // 				0: plan.Node_TABLE_SCAN,
   176  // 				1: plan.Node_TABLE_SCAN,
   177  // 				2: plan.Node_JOIN,
   178  // 				3: plan.Node_PROJECT,
   179  // 				4: plan.Node_TABLE_SCAN,
   180  // 				5: plan.Node_JOIN,
   181  // 				6: plan.Node_PROJECT,
   182  // 			},
   183  // 			children: map[int][]int32{
   184  // 				2: {0, 1},
   185  // 				3: {2},
   186  // 				5: {3, 4},
   187  // 				6: {5},
   188  // 			},
   189  // 		},
   190  // 		// 6 nodes - SCAN, SCAN, JOIN, SCAN, JOIN, SORT  //join three table
   191  // 		"SELECT l.L_ORDERKEY FROM CUSTOMER c, ORDERS o, LINEITEM l WHERE c.C_CUSTKEY = o.O_CUSTKEY and l.L_ORDERKEY = o.O_ORDERKEY and o.O_ORDERKEY < 10 order by c.C_CUSTKEY": {
   192  // 			steps: []int32{7},
   193  // 			nodeType: map[int]plan.Node_NodeType{
   194  // 				0: plan.Node_TABLE_SCAN,
   195  // 				1: plan.Node_TABLE_SCAN,
   196  // 				2: plan.Node_JOIN,
   197  // 				3: plan.Node_PROJECT,
   198  // 				4: plan.Node_TABLE_SCAN,
   199  // 				5: plan.Node_JOIN,
   200  // 				6: plan.Node_PROJECT,
   201  // 				7: plan.Node_SORT,
   202  // 			},
   203  // 			children: map[int][]int32{
   204  // 				2: {0, 1},
   205  // 				3: {2},
   206  // 				5: {3, 4},
   207  // 				6: {5},
   208  // 				7: {6},
   209  // 			},
   210  // 		},
   211  // 		// 3 nodes  //Derived table
   212  // 		"select c_custkey from (select c_custkey, count(C_NATIONKEY) ff from CUSTOMER group by c_custkey) a where ff > 0": {
   213  // 			steps: []int32{2},
   214  // 			nodeType: map[int]plan.Node_NodeType{
   215  // 				0: plan.Node_TABLE_SCAN,
   216  // 				1: plan.Node_AGG,
   217  // 				2: plan.Node_PROJECT,
   218  // 			},
   219  // 			children: map[int][]int32{
   220  // 				1: {0},
   221  // 				2: {1},
   222  // 			},
   223  // 		},
   224  // 		// 4 nodes  //Derived table
   225  // 		"select c_custkey from (select c_custkey, count(C_NATIONKEY) ff from CUSTOMER group by c_custkey ) a where ff > 0 order by c_custkey": {
   226  // 			steps: []int32{3},
   227  // 			nodeType: map[int]plan.Node_NodeType{
   228  // 				0: plan.Node_TABLE_SCAN,
   229  // 				1: plan.Node_AGG,
   230  // 				2: plan.Node_PROJECT,
   231  // 				3: plan.Node_SORT,
   232  // 			},
   233  // 			children: map[int][]int32{
   234  // 				1: {0},
   235  // 				2: {1},
   236  // 				3: {2},
   237  // 			},
   238  // 		},
   239  // 		// Derived table join normal table
   240  // 		"select c_custkey from (select c_custkey, count(C_NATIONKEY) ff from CUSTOMER group by c_custkey ) a join NATION b on a.c_custkey = b.N_REGIONKEY where b.N_NATIONKEY > 10 order By b.N_REGIONKEY": {
   241  // 			steps: []int32{6},
   242  // 			nodeType: map[int]plan.Node_NodeType{
   243  // 				0: plan.Node_TABLE_SCAN,
   244  // 				1: plan.Node_AGG,
   245  // 				2: plan.Node_PROJECT,
   246  // 				3: plan.Node_TABLE_SCAN,
   247  // 				4: plan.Node_JOIN,
   248  // 				5: plan.Node_PROJECT,
   249  // 				6: plan.Node_SORT,
   250  // 			},
   251  // 			children: map[int][]int32{
   252  // 				1: {0},
   253  // 				2: {1},
   254  // 				4: {2, 3},
   255  // 				5: {4},
   256  // 				6: {5},
   257  // 			},
   258  // 		},
   259  // 		// insert from values
   260  // 		"INSERT NATION (N_NATIONKEY, N_REGIONKEY, N_NAME) VALUES (1, 21, 'NAME1'), (2, 22, 'NAME2')": {
   261  // 			steps: []int32{1},
   262  // 			nodeType: map[int]plan.Node_NodeType{
   263  // 				0: plan.Node_VALUE_SCAN,
   264  // 				1: plan.Node_INSERT,
   265  // 			},
   266  // 			children: map[int][]int32{
   267  // 				1: {0},
   268  // 			},
   269  // 		},
   270  // 		// insert from select
   271  // 		"INSERT NATION SELECT * FROM NATION2": {
   272  // 			steps: []int32{1},
   273  // 			nodeType: map[int]plan.Node_NodeType{
   274  // 				0: plan.Node_TABLE_SCAN,
   275  // 				1: plan.Node_INSERT,
   276  // 			},
   277  // 			children: map[int][]int32{
   278  // 				1: {0},
   279  // 			},
   280  // 		},
   281  // 		// update
   282  // 		"UPDATE NATION SET N_NAME ='U1', N_REGIONKEY=N_REGIONKEY+2 WHERE N_NATIONKEY > 10 LIMIT 20": {
   283  // 			steps: []int32{1},
   284  // 			nodeType: map[int]plan.Node_NodeType{
   285  // 				0: plan.Node_TABLE_SCAN,
   286  // 				1: plan.Node_UPDATE,
   287  // 			},
   288  // 			children: map[int][]int32{
   289  // 				1: {0},
   290  // 			},
   291  // 		},
   292  // 		// delete
   293  // 		"DELETE FROM NATION WHERE N_NATIONKEY > 10 LIMIT 20": {
   294  // 			steps: []int32{1},
   295  // 			nodeType: map[int]plan.Node_NodeType{
   296  // 				0: plan.Node_TABLE_SCAN,
   297  // 				1: plan.Node_DELETE,
   298  // 			},
   299  // 		},
   300  // 		// uncorrelated subquery
   301  // 		"SELECT * FROM NATION where N_REGIONKEY > (select max(R_REGIONKEY) from REGION)": {
   302  // 			steps: []int32{0},
   303  // 			nodeType: map[int]plan.Node_NodeType{
   304  // 				0: plan.Node_TABLE_SCAN, //nodeid = 1  here is the subquery
   305  // 				1: plan.Node_TABLE_SCAN, //nodeid = 0, here is SELECT * FROM NATION where N_REGIONKEY > [subquery]
   306  // 			},
   307  // 			children: map[int][]int32{},
   308  // 		},
   309  // 		// correlated subquery
   310  // 		`SELECT * FROM NATION where N_REGIONKEY >
   311  // 			(select avg(R_REGIONKEY) from REGION where R_REGIONKEY < N_REGIONKEY group by R_NAME)
   312  // 		order by N_NATIONKEY`: {
   313  // 			steps: []int32{3},
   314  // 			nodeType: map[int]plan.Node_NodeType{
   315  // 				0: plan.Node_TABLE_SCAN, //nodeid = 1  subquery node,so,wo pop it to top
   316  // 				1: plan.Node_TABLE_SCAN, //nodeid = 0
   317  // 				2: plan.Node_AGG,        //nodeid = 2  subquery node,so,wo pop it to top
   318  // 				3: plan.Node_SORT,       //nodeid = 3
   319  // 			},
   320  // 			children: map[int][]int32{
   321  // 				2: {1}, //nodeid = 2, have children(NodeId=1, position=0)
   322  // 				3: {0}, //nodeid = 3, have children(NodeId=0, position=2)
   323  // 			},
   324  // 		},
   325  // 		// cte
   326  // 		`with tbl(col1, col2) as (select n_nationkey, n_name from nation) select * from tbl order by col2`: {
   327  // 			steps: []int32{1, 3},
   328  // 			nodeType: map[int]plan.Node_NodeType{
   329  // 				0: plan.Node_TABLE_SCAN,
   330  // 				1: plan.Node_MATERIAL,
   331  // 				2: plan.Node_MATERIAL_SCAN,
   332  // 				3: plan.Node_SORT,
   333  // 			},
   334  // 			children: map[int][]int32{
   335  // 				1: {0},
   336  // 				3: {2},
   337  // 			},
   338  // 		},
   339  // 	}
   340  
   341  // 	// run test and check node tree
   342  // 	for sql, check := range nodeTreeCheckList {
   343  // 		mock := NewMockOptimizer(false)
   344  // 		logicPlan, err := runOneStmt(mock, t, sql)
   345  // 		query := logicPlan.GetQuery()
   346  // 		if err != nil {
   347  // 			t.Fatalf("%+v, sql=%v", err, sql)
   348  // 		}
   349  // 		if len(query.Steps) != len(check.steps) {
   350  // 			t.Fatalf("run sql[%+v] error, root should be [%+v] but now is [%+v]", sql, check.steps, query.Steps)
   351  // 		}
   352  // 		for idx, step := range query.Steps {
   353  // 			if step != check.steps[idx] {
   354  // 				t.Fatalf("run sql[%+v] error, root should be [%+v] but now is [%+v]", sql, check.steps, query.Steps)
   355  // 			}
   356  // 		}
   357  // 		for idx, typ := range check.nodeType {
   358  // 			if idx >= len(query.Nodes) {
   359  // 				t.Fatalf("run sql[%+v] error, query.Nodes[%+v].NodeType not exist", sql, idx)
   360  // 			}
   361  // 			if query.Nodes[idx].NodeType != typ {
   362  // 				t.Fatalf("run sql[%+v] error, query.Nodes[%+v].NodeType should be [%+v] but now is [%+v]", sql, idx, typ, query.Nodes[idx].NodeType)
   363  // 			}
   364  // 		}
   365  // 		for idx, children := range check.children {
   366  // 			if idx >= len(query.Nodes) {
   367  // 				t.Fatalf("run sql[%+v] error, query.Nodes[%+v].NodeType not exist", sql, idx)
   368  // 			}
   369  // 			if !reflect.DeepEqual(query.Nodes[idx].Children, children) {
   370  // 				t.Fatalf("run sql[%+v] error, query.Nodes[%+v].Children should be [%+v] but now is [%+v]", sql, idx, children, query.Nodes[idx].Children)
   371  // 			}
   372  // 		}
   373  // 	}
   374  // }
   375  
   376  // test single table plan building
   377  func TestSingleTableSQLBuilder(t *testing.T) {
   378  	mock := NewMockOptimizer(false)
   379  
   380  	// should pass
   381  	sqls := []string{
   382  		"SELECT '1900-01-01 00:00:00' + INTERVAL 2147483648 SECOND",
   383  		"SELECT N_NAME, N_REGIONKEY FROM NATION WHERE N_REGIONKEY > 0 AND N_NAME LIKE '%AA' ORDER BY N_NAME DESC, N_REGIONKEY LIMIT 10, 20",
   384  		"SELECT N_NAME, N_REGIONKEY a FROM NATION WHERE N_REGIONKEY > 0 ORDER BY a DESC", //test alias
   385  		"SELECT NATION.N_NAME FROM NATION",                                       //test alias
   386  		"SELECT * FROM NATION",                                                   //test star
   387  		"SELECT a.* FROM NATION a",                                               //test star
   388  		"SELECT count(*) FROM NATION",                                            //test star
   389  		"SELECT count(*) FROM NATION group by N_NAME",                            //test star
   390  		"SELECT N_NAME, count(distinct N_REGIONKEY) FROM NATION group by N_NAME", //test distinct agg function
   391  		"SELECT N_NAME, MAX(N_REGIONKEY) FROM NATION GROUP BY N_NAME HAVING MAX(N_REGIONKEY) > 10", //test agg
   392  		"SELECT DISTINCT N_NAME FROM NATION", //test distinct
   393  		"select sum(n_nationkey) as s from nation order by s",
   394  		"select date_add(date '2001-01-01', interval 1 day) as a",
   395  		"select date_sub(date '2001-01-01', interval '1' day) as a",
   396  		"select date_add('2001-01-01', interval '1' day) as a",
   397  		"select n_name, count(*) from nation group by n_name order by 2 asc",
   398  		"select count(distinct 12)",
   399  		"select nullif(n_name, n_comment), ifnull(n_comment, n_name) from nation",
   400  
   401  		"select 18446744073709551500",
   402  		"select 0xffffffffffffffff",
   403  		"select 0xffff",
   404  
   405  		"SELECT N_REGIONKEY + 2 as a, N_REGIONKEY/2, N_REGIONKEY* N_NATIONKEY, N_REGIONKEY % N_NATIONKEY, N_REGIONKEY - N_NATIONKEY FROM NATION WHERE -N_NATIONKEY < -20", //test more expr
   406  		"SELECT N_REGIONKEY FROM NATION where N_REGIONKEY >= N_NATIONKEY or (N_NAME like '%ddd' and N_REGIONKEY >0.5)",                                                    //test more expr
   407  		"SELECT N_REGIONKEY FROM NATION where N_REGIONKEY between 2 and 2 OR N_NATIONKEY not between 3 and 10",                                                            //test more expr
   408  		// "SELECT N_REGIONKEY FROM NATION where N_REGIONKEY is null and N_NAME is not null",
   409  		"SELECT N_REGIONKEY FROM NATION where N_REGIONKEY IN (1, 2)",  //test more expr
   410  		"SELECT N_REGIONKEY FROM NATION where N_REGIONKEY NOT IN (1)", //test more expr
   411  		"select N_REGIONKEY from nation group by N_REGIONKEY having abs(nation.N_REGIONKEY - 1) >10",
   412  
   413  		"SELECT -1",
   414  		"select date_add('1997-12-31 23:59:59',INTERVAL 100000 SECOND)",
   415  		"select date_sub('1997-12-31 23:59:59',INTERVAL 2 HOUR)",
   416  		"select @str_var, @int_var, @bool_var, @float_var, @null_var",
   417  		"select @str_var, @@global.int_var, @@session.bool_var",
   418  		"select n_name from nation where n_name != @str_var and n_regionkey > @int_var",
   419  		"select n_name from nation where n_name != @@global.str_var and n_regionkey > @@session.int_var",
   420  		"select distinct(n_name), ((abs(n_regionkey))) from nation",
   421  		"SET @var = abs(-1), @@session.string_var = 'aaa'",
   422  		"SET NAMES 'utf8mb4' COLLATE 'utf8mb4_general_ci'",
   423  		"SELECT DISTINCT N_NAME FROM NATION ORDER BY N_NAME", //test distinct with order by
   424  
   425  		"prepare stmt1 from select * from nation",
   426  		"prepare stmt1 from select * from nation where n_name = ?",
   427  		"prepare stmt1 from 'select * from nation where n_name = ?'",
   428  		"prepare stmt1 from 'insert into nation select * from nation2 where n_name = ?'",
   429  		"prepare stmt1 from 'select * from nation where n_name = ?'",
   430  		"prepare stmt1 from 'drop table if exists t1'",
   431  		"prepare stmt1 from 'create table t1 (a int)'",
   432  		"prepare stmt1 from select N_REGIONKEY from nation group by N_REGIONKEY having abs(nation.N_REGIONKEY - ?) > ?",
   433  		"execute stmt1",
   434  		"execute stmt1 using @str_var, @@global.int_var",
   435  		"deallocate prepare stmt1",
   436  		"drop prepare stmt1",
   437  		"select count(n_name) from nation limit 10",
   438  		"select l_shipdate + interval '1' day from lineitem",
   439  		"select interval '1' day + l_shipdate  from lineitem",
   440  		"select interval '1' day + cast('2022-02-02 00:00:00' as datetime)",
   441  		"select cast('2022-02-02 00:00:00' as datetime) + interval '1' day",
   442  		"select true is unknown",
   443  		"select null is not unknown",
   444  		"select 1 as c,  1/2, abs(-2)",
   445  
   446  		"select date('2022-01-01'), adddate(time'00:00:00', interval 1 day), subdate(time'00:00:00', interval 1 week), '2007-01-01' + interval 1 month, '2007-01-01' -  interval 1 hour",
   447  		"select 2222332222222223333333333333333333, 0x616263,-10, bit_and(2), bit_or(2), bit_xor(10.1), 'aaa' like '%a',str_to_date('04/31/2004', '%m/%d/%Y'),unix_timestamp(from_unixtime(2147483647))",
   448  		"select max(n_nationkey) over  (partition by N_REGIONKEY) from nation",
   449  		"select * from generate_series(1, 5) g",
   450  		"select * from nation where n_name like ? or n_nationkey > 10 order by 2 limit '10'",
   451  
   452  		"values row(1,1), row(2,2), row(3,3) order by column_0 limit 2",
   453  		"select * from (values row(1,1), row(2,2), row(3,3)) a (c1, c2)",
   454  	}
   455  	runTestShouldPass(mock, t, sqls, false, false)
   456  
   457  	// should error
   458  	sqls = []string{
   459  		"SELECT N_NAME, N_REGIONKEY FROM table_not_exist",                   //table not exist
   460  		"SELECT N_NAME, column_not_exist FROM NATION",                       //column not exist
   461  		"SELECT N_NAME, N_REGIONKEY a FROM NATION ORDER BY cccc",            //column alias not exist
   462  		"SELECT N_NAME, b.N_REGIONKEY FROM NATION a ORDER BY b.N_REGIONKEY", //table alias not exist
   463  		"SELECT N_NAME FROM NATION WHERE ffff(N_REGIONKEY) > 0",             //function name not exist
   464  		"SELECT NATION.N_NAME FROM NATION a",                                // mysql should error, but i don't think it is necesssary
   465  		"select n_nationkey, sum(n_nationkey) from nation",
   466  		"SET @var = abs(a)", // can't use column
   467  		"SET @var = avg(2)", // can't use agg function
   468  
   469  		"SELECT DISTINCT N_NAME FROM NATION GROUP BY N_REGIONKEY", //test distinct with group by
   470  		"SELECT DISTINCT N_NAME FROM NATION ORDER BY N_REGIONKEY", //test distinct with order by
   471  		//"select 18446744073709551500",                             //over int64
   472  		//"select 0xffffffffffffffff",                               //over int64
   473  	}
   474  	runTestShouldError(mock, t, sqls)
   475  }
   476  
   477  // test join table plan building
   478  func TestJoinTableSqlBuilder(t *testing.T) {
   479  	mock := NewMockOptimizer(false)
   480  
   481  	// should pass
   482  	sqls := []string{
   483  		"SELECT N_NAME,N_REGIONKEY FROM NATION join REGION on NATION.N_REGIONKEY = REGION.R_REGIONKEY",
   484  		"SELECT N_NAME, N_REGIONKEY FROM NATION join REGION on NATION.N_REGIONKEY = REGION.R_REGIONKEY WHERE NATION.N_REGIONKEY > 0",
   485  		"SELECT N_NAME, NATION2.R_REGIONKEY FROM NATION2 join REGION using(R_REGIONKEY) WHERE NATION2.R_REGIONKEY > 0",
   486  		"SELECT N_NAME, NATION2.R_REGIONKEY FROM NATION2 NATURAL JOIN REGION WHERE NATION2.R_REGIONKEY > 0",
   487  		"SELECT N_NAME FROM NATION NATURAL JOIN REGION",                                                                                                     //have no same column name but it's ok
   488  		"SELECT N_NAME,N_REGIONKEY FROM NATION a join REGION b on a.N_REGIONKEY = b.R_REGIONKEY WHERE a.N_REGIONKEY > 0",                                    //test alias
   489  		"SELECT l.L_ORDERKEY a FROM CUSTOMER c, ORDERS o, LINEITEM l WHERE c.C_CUSTKEY = o.O_CUSTKEY and l.L_ORDERKEY = o.O_ORDERKEY and o.O_ORDERKEY < 10", //join three tables
   490  		"SELECT c.* FROM CUSTOMER c, ORDERS o, LINEITEM l WHERE c.C_CUSTKEY = o.O_CUSTKEY and l.L_ORDERKEY = o.O_ORDERKEY",                                  //test star
   491  		"SELECT * FROM CUSTOMER c, ORDERS o, LINEITEM l WHERE c.C_CUSTKEY = o.O_CUSTKEY and l.L_ORDERKEY = o.O_ORDERKEY",                                    //test star
   492  		"SELECT a.* FROM NATION a join REGION b on a.N_REGIONKEY = b.R_REGIONKEY WHERE a.N_REGIONKEY > 0",                                                   //test star
   493  		"SELECT * FROM NATION a join REGION b on a.N_REGIONKEY = b.R_REGIONKEY WHERE a.N_REGIONKEY > 0",
   494  		"SELECT N_NAME, R_REGIONKEY FROM NATION2 join REGION using(R_REGIONKEY)",
   495  		"select nation.n_name from nation join nation2 on nation.n_name !='a' join region on nation.n_regionkey = region.r_regionkey",
   496  		"select * from nation, nation2, region",
   497  	}
   498  	runTestShouldPass(mock, t, sqls, false, false)
   499  
   500  	// should error
   501  	sqls = []string{
   502  		"SELECT N_NAME,N_REGIONKEY FROM NATION join REGION on NATION.N_REGIONKEY = REGION.NotExistColumn",                    //column not exist
   503  		"SELECT N_NAME, R_REGIONKEY FROM NATION join REGION using(R_REGIONKEY)",                                              //column not exist
   504  		"SELECT N_NAME,N_REGIONKEY FROM NATION a join REGION b on a.N_REGIONKEY = b.R_REGIONKEY WHERE aaaaa.N_REGIONKEY > 0", //table alias not exist
   505  		"select *", //No table used
   506  	}
   507  	runTestShouldError(mock, t, sqls)
   508  }
   509  
   510  // test derived table plan building
   511  func TestDerivedTableSqlBuilder(t *testing.T) {
   512  	mock := NewMockOptimizer(false)
   513  	// should pass
   514  	sqls := []string{
   515  		"select c_custkey from (select c_custkey from CUSTOMER ) a",
   516  		"select c_custkey from (select c_custkey from CUSTOMER group by c_custkey ) a",
   517  		"select col1 from (select c_custkey from CUSTOMER group by c_custkey ) a(col1)",
   518  		"select c_custkey from (select c_custkey, count(C_NATIONKEY) ff from CUSTOMER group by c_custkey ) a where ff > 0 order by c_custkey",
   519  		"select col1 from (select c_custkey, count(C_NATIONKEY) ff from CUSTOMER group by c_custkey ) a(col1, col2) where col2 > 0 order by col1",
   520  		"select c_custkey from (select c_custkey, count(C_NATIONKEY) ff from CUSTOMER group by c_custkey ) a join NATION b on a.c_custkey = b.N_REGIONKEY where b.N_NATIONKEY > 10",
   521  		"select a.* from (select c_custkey, count(C_NATIONKEY) ff from CUSTOMER group by c_custkey ) a join NATION b on a.c_custkey = b.N_REGIONKEY where b.N_NATIONKEY > 10",
   522  		"select * from (select c_custkey, count(C_NATIONKEY) ff from CUSTOMER group by c_custkey ) a join NATION b on a.c_custkey = b.N_REGIONKEY where b.N_NATIONKEY > 10",
   523  	}
   524  	runTestShouldPass(mock, t, sqls, false, false)
   525  
   526  	// should error
   527  	sqls = []string{
   528  		"select C_NAME from (select c_custkey from CUSTOMER) a",                               //column not exist
   529  		"select c_custkey2222 from (select c_custkey from CUSTOMER group by c_custkey ) a",    //column not exist
   530  		"select col1 from (select c_custkey from CUSTOMER group by c_custkey ) a(col1, col2)", //column length not match
   531  		"select c_custkey from (select c_custkey from CUSTOMER group by c_custkey) a(col1)",   //column not exist
   532  	}
   533  	runTestShouldError(mock, t, sqls)
   534  }
   535  
   536  // test derived table plan building
   537  func TestUnionSqlBuilder(t *testing.T) {
   538  	mock := NewMockOptimizer(false)
   539  	// should pass
   540  	sqls := []string{
   541  		"(select 1) union (select 1)",
   542  		"(((select n_nationkey from nation order by n_nationkey))) union (((select n_nationkey from nation order by n_nationkey)))",
   543  		"select 1 union select 2",
   544  		"select 1 union (select 2 union select 3)",
   545  		"(select 1 union select 2) union select 3 intersect select 4 order by 1",
   546  		"select 1 union select null",
   547  		"select n_name from nation intersect select n_name from nation2",
   548  		"select n_name from nation minus select n_name from nation2",
   549  		"select 1 union select 2 intersect select 2 union all select 1.1 minus select 22222",
   550  		"select 1 as a union select 2 order by a limit 1",
   551  		"select n_name from nation union select n_comment from nation order by n_name",
   552  		"with qn (foo, bar) as (select 1 as col, 2 as coll union select 4, 5) select qn1.bar from qn qn1",
   553  		"select n_name, n_comment from nation union all select n_name, n_comment from nation2",
   554  		"select n_name from nation intersect all select n_name from nation2",
   555  	}
   556  	runTestShouldPass(mock, t, sqls, false, false)
   557  
   558  	// should error
   559  	sqls = []string{
   560  		"select 1 union select 2, 'a'",
   561  		"select n_name as a from nation union select n_comment from nation order by n_name",
   562  		"select n_name from nation minus all select n_name from nation2", // not support
   563  	}
   564  	runTestShouldError(mock, t, sqls)
   565  }
   566  
   567  // test CTE plan building
   568  func TestCTESqlBuilder(t *testing.T) {
   569  	mock := NewMockOptimizer(false)
   570  
   571  	// should pass
   572  	sqls := []string{
   573  		"WITH qn AS (SELECT * FROM nation) SELECT * FROM qn;",
   574  		"WITH qn(a, b) AS (SELECT * FROM nation) SELECT * FROM qn;",
   575  		"with qn0 as (select 1), qn1 as (select * from qn0), qn2 as (select 1), qn3 as (select 1 from qn1, qn2) select 1 from qn3",
   576  
   577  		`WITH qn AS (select "outer" as a)
   578  		SELECT (WITH qn AS (SELECT "inner" as a) SELECT a from qn),
   579  		qn.a
   580  		FROM qn`,
   581  	}
   582  	runTestShouldPass(mock, t, sqls, false, false)
   583  
   584  	// should error
   585  	sqls = []string{
   586  		`with qn1 as (with qn3 as (select * from qn2) select * from qn3),
   587  		qn2 as (select 1)
   588  		select * from qn1`,
   589  
   590  		`WITH qn2 AS (SELECT a FROM qn WHERE a IS NULL or a>0),
   591  		qn AS (SELECT b as a FROM qn2)
   592  		SELECT qn.a  FROM qn`,
   593  	}
   594  	runTestShouldError(mock, t, sqls)
   595  }
   596  
   597  func TestInsert(t *testing.T) {
   598  	mock := NewMockOptimizer(false)
   599  	// should pass
   600  	sqls := []string{
   601  		"INSERT INTO NATION VALUES (1, 'NAME1',21, 'COMMENT1'), (2, 'NAME2', 22, 'COMMENT2')",
   602  		"INSERT INTO NATION (N_NATIONKEY, N_REGIONKEY, N_NAME, N_COMMENT) VALUES (1, 21, 'NAME1','comment1'), (2, 22, 'NAME2', 'comment2')",
   603  		"INSERT INTO NATION SELECT * FROM NATION2",
   604  	}
   605  	runTestShouldPass(mock, t, sqls, false, false)
   606  
   607  	// should error
   608  	sqls = []string{
   609  		"INSERT NATION VALUES (1, 'NAME1',21, 'COMMENT1'), ('NAME2', 22, 'COMMENT2')",                                // doesn't match value count
   610  		"INSERT NATION (N_NATIONKEY, N_REGIONKEY, N_NAME) VALUES (1, 'NAME1'), (2, 22, 'NAME2')",                     // doesn't match value count
   611  		"INSERT NATION (N_NATIONKEY, N_REGIONKEY, N_NAME2222) VALUES (1, 21, 'NAME1'), (2, 22, 'NAME2')",             // column not exist
   612  		"INSERT NATION333 (N_NATIONKEY, N_REGIONKEY, N_NAME2222) VALUES (1, 2, 'NAME1'), (2, 22, 'NAME2')",           // table not exist
   613  		"INSERT NATION (N_NATIONKEY, N_REGIONKEY, N_NAME2222) VALUES (1, 'should int32', 'NAME1'), (2, 22, 'NAME2')", // column type not match
   614  		"INSERT NATION (N_NATIONKEY, N_REGIONKEY, N_NAME2222) VALUES (1, 2.22, 'NAME1'), (2, 22, 'NAME2')",           // column type not match
   615  		"INSERT NATION (N_NATIONKEY, N_REGIONKEY, N_NAME2222) VALUES (1, 2, 'NAME1'), (2, 22, 'NAME2')",              // function expr not support now
   616  		"INSERT INTO region SELECT * FROM NATION2",                                                                   // column length not match
   617  		"INSERT INTO region SELECT 1, 2, 3, 4, 5, 6 FROM NATION2",                                                    // column length not match
   618  		"INSERT NATION333 (N_NATIONKEY, N_REGIONKEY, N_NAME2222) SELECT 1, 2, 3 FROM NATION2",                        // table not exist
   619  	}
   620  	runTestShouldError(mock, t, sqls)
   621  }
   622  
   623  func TestUpdate(t *testing.T) {
   624  	mock := NewMockOptimizer(true)
   625  	// should pass
   626  	sqls := []string{
   627  		"UPDATE NATION SET N_NAME ='U1', N_REGIONKEY=2",
   628  		"UPDATE NATION SET N_NAME ='U1', N_REGIONKEY=2 WHERE N_NATIONKEY > 10 LIMIT 20",
   629  		"UPDATE NATION SET N_NAME ='U1', N_REGIONKEY=N_REGIONKEY+2 WHERE N_NATIONKEY > 10 LIMIT 20",
   630  		"update NATION a join NATION2 b on a.N_REGIONKEY = b.R_REGIONKEY set a.N_NAME = 'aa'",
   631  		"prepare stmt1 from 'update nation set n_name = ? where n_nationkey > ?'",
   632  		"drop index idx1 on test_idx",
   633  	}
   634  	runTestShouldPass(mock, t, sqls, false, false)
   635  
   636  	// should error
   637  	sqls = []string{
   638  		"UPDATE NATION SET N_NAME2 ='U1', N_REGIONKEY=2",    // column not exist
   639  		"UPDATE NATION2222 SET N_NAME ='U1', N_REGIONKEY=2", // table not exist
   640  	}
   641  	runTestShouldError(mock, t, sqls)
   642  }
   643  
   644  func TestDelete(t *testing.T) {
   645  	mock := NewMockOptimizer(true)
   646  	// should pass
   647  	sqls := []string{
   648  		"DELETE FROM NATION",
   649  		"DELETE FROM NATION WHERE N_NATIONKEY > 10",
   650  		"DELETE FROM NATION WHERE N_NATIONKEY > 10 LIMIT 20",
   651  		"delete nation from nation left join nation2 on nation.n_nationkey = nation2.n_nationkey",
   652  		"delete from nation",
   653  		"delete nation, nation2 from nation join nation2 on nation.n_name = nation2.n_name",
   654  		"prepare stmt1 from 'delete from nation where n_nationkey > ?'",
   655  	}
   656  	runTestShouldPass(mock, t, sqls, false, false)
   657  
   658  	// should error
   659  	sqls = []string{
   660  		"DELETE FROM NATION2222",                     // table not exist
   661  		"DELETE FROM NATION WHERE N_NATIONKEY2 > 10", // column not found
   662  	}
   663  	runTestShouldError(mock, t, sqls)
   664  }
   665  
   666  func TestSubQuery(t *testing.T) {
   667  	mock := NewMockOptimizer(false)
   668  	// should pass
   669  	sqls := []string{
   670  		"SELECT * FROM NATION where N_REGIONKEY > (select max(R_REGIONKEY) from REGION)",                                 // unrelated
   671  		"SELECT * FROM NATION where N_REGIONKEY in (select max(R_REGIONKEY) from REGION)",                                // unrelated
   672  		"SELECT * FROM NATION where N_REGIONKEY not in (select max(R_REGIONKEY) from REGION)",                            // unrelated
   673  		"SELECT * FROM NATION where exists (select max(R_REGIONKEY) from REGION)",                                        // unrelated
   674  		"SELECT * FROM NATION where N_REGIONKEY > (select max(R_REGIONKEY) from REGION where R_REGIONKEY = N_REGIONKEY)", // related
   675  		//"DELETE FROM NATION WHERE N_NATIONKEY > 10",
   676  		`select
   677  		sum(l_extendedprice) / 7.0 as avg_yearly
   678  	from
   679  		lineitem,
   680  		part
   681  	where
   682  		p_partkey = l_partkey
   683  		and p_brand = 'Brand#54'
   684  		and p_container = 'LG BAG'
   685  		and l_quantity < (
   686  			select
   687  				0.2 * avg(l_quantity)
   688  			from
   689  				lineitem
   690  			where
   691  				l_partkey = p_partkey
   692  		);`, //tpch q17
   693  		"select * from nation where n_regionkey in (select r_regionkey from region) and n_nationkey not in (1,2) and n_nationkey = some (select n_nationkey from nation2)",
   694  	}
   695  	runTestShouldPass(mock, t, sqls, false, false)
   696  
   697  	// should error
   698  	sqls = []string{
   699  		"SELECT * FROM NATION where N_REGIONKEY > (select max(R_REGIONKEY) from REGION222)",                                 // table not exist
   700  		"SELECT * FROM NATION where N_REGIONKEY > (select max(R_REGIONKEY) from REGION where R_REGIONKEY < N_REGIONKEY222)", // column not exist
   701  		"SELECT * FROM NATION where N_REGIONKEY > (select max(R_REGIONKEY) from REGION where R_REGIONKEY < N_REGIONKEY)",    // related
   702  	}
   703  	runTestShouldError(mock, t, sqls)
   704  }
   705  
   706  func TestMysqlCompatibilityMode(t *testing.T) {
   707  	mock := NewMockOptimizer(false)
   708  
   709  	sqls := []string{
   710  		"SELECT n_nationkey FROM NATION group by n_name",
   711  		"SELECT n_nationkey, min(n_name) FROM NATION",
   712  		"SELECT n_nationkey + 100 FROM NATION group by n_name",
   713  	}
   714  	// withou mysql compatibility
   715  	runTestShouldError(mock, t, sqls)
   716  	// with mysql compatibility
   717  	mock.ctxt.mysqlCompatible = true
   718  	runTestShouldPass(mock, t, sqls, false, false)
   719  }
   720  
   721  func TestTcl(t *testing.T) {
   722  	mock := NewMockOptimizer(false)
   723  	// should pass
   724  	sqls := []string{
   725  		"start transaction",
   726  		"start transaction read write",
   727  		"begin",
   728  		"commit and chain",
   729  		"commit and chain no release",
   730  		"rollback and chain",
   731  	}
   732  	runTestShouldPass(mock, t, sqls, false, false)
   733  
   734  	// should error
   735  	sqls = []string{}
   736  	runTestShouldError(mock, t, sqls)
   737  }
   738  
   739  func TestDdl(t *testing.T) {
   740  	mock := NewMockOptimizer(false)
   741  	// should pass
   742  	sqls := []string{
   743  		"create database db_name",               //db not exists and pass
   744  		"create database if not exists db_name", //db not exists but pass
   745  		"create database if not exists tpch",    //db exists and pass
   746  		"drop database if exists db_name",       //db not exists but pass
   747  		"drop database tpch",                    //db exists, pass
   748  		"create view v1 as select * from nation",
   749  
   750  		"create table tbl_name (t bool(20) comment 'dd', b int unsigned, c char(20), d varchar(20), primary key(b), index idx_t(c)) comment 'test comment'",
   751  		"create table if not exists tbl_name (b int default 20 primary key, c char(20) default 'ss', d varchar(20) default 'kkk')",
   752  		"create table if not exists nation (t bool(20), b int, c char(20), d varchar(20))",
   753  		"drop table if exists tbl_name",
   754  		"drop table if exists nation",
   755  		"drop table nation",
   756  		"drop table tpch.nation",
   757  		"drop table if exists tpch.tbl_not_exist",
   758  		"drop table if exists db_not_exist.tbl",
   759  		"drop view v1",
   760  		"truncate nation",
   761  		"truncate tpch.nation",
   762  		"truncate table nation",
   763  		"truncate table tpch.nation",
   764  		"create unique index idx_name on nation(n_regionkey)",
   765  	}
   766  	runTestShouldPass(mock, t, sqls, false, false)
   767  
   768  	// should error
   769  	sqls = []string{
   770  		// "create database tpch",  // check in pipeline now
   771  		// "drop database db_name", // check in pipeline now
   772  		// "create table nation (t bool(20), b int, c char(20), d varchar(20))",             // check in pipeline now
   773  		"create table nation (b int primary key, c char(20) primary key, d varchar(20))", //Multiple primary key
   774  		"drop table tbl_name",           //table not exists in tpch
   775  		"drop table tpch.tbl_not_exist", //database not exists
   776  		"drop table db_not_exist.tbl",   //table not exists
   777  	}
   778  	runTestShouldError(mock, t, sqls)
   779  }
   780  
   781  func TestShow(t *testing.T) {
   782  	mock := NewMockOptimizer(false)
   783  	// should pass
   784  	sqls := []string{
   785  		"show variables",
   786  		//"show create database tpch",
   787  		"show create table nation",
   788  		"show create table tpch.nation",
   789  		"show databases",
   790  		"show databases like '%d'",
   791  		"show databases where `Database` = '11'",
   792  		"show databases where `Database` = '11' or `Database` = 'ddd'",
   793  		"show tables",
   794  		"show tables from tpch",
   795  		"show tables like '%dd'",
   796  		"show tables from tpch where `Tables_in_tpch` = 'aa' or `Tables_in_tpch` like '%dd'",
   797  		"show columns from nation",
   798  		"show columns from nation from tpch",
   799  		"show columns from nation where `Field` like '%ff' or `Type` = 1 or `Null` = 0",
   800  		"show create view v1",
   801  		"show create table v1",
   802  		"show table_number",
   803  		"show table_number from tpch",
   804  		"show column_number from nation",
   805  		"show config",
   806  		"show index from tpch.nation",
   807  		"show locks",
   808  		"show node list",
   809  		"show grants for ROLE role1",
   810  		"show function status",
   811  		"show function status like '%ff'",
   812  		// "show grants",
   813  	}
   814  	runTestShouldPass(mock, t, sqls, false, false)
   815  
   816  	// should error
   817  	sqls = []string{
   818  		"show create database db_not_exist",                    //db no exist
   819  		"show create table tpch.nation22",                      //table not exist
   820  		"show create view vvv",                                 //view not exist
   821  		"show databases where d ='a'",                          //Column not exist,  show databases only have one column named 'Database'
   822  		"show databases where `Databaseddddd` = '11'",          //column not exist
   823  		"show tables from tpch22222",                           //database not exist
   824  		"show tables from tpch where Tables_in_tpch222 = 'aa'", //column not exist
   825  		"show columns from nation_ddddd",                       //table not exist
   826  		"show columns from nation_ddddd from tpch",             //table not exist
   827  		"show columns from nation where `Field22` like '%ff'",  //column not exist
   828  		"show index from tpch.dddd",
   829  		"show table_number from tpch222",
   830  		"show column_number from nation222",
   831  	}
   832  	runTestShouldError(mock, t, sqls)
   833  }
   834  
   835  func TestResultColumns(t *testing.T) {
   836  	mock := NewMockOptimizer(false)
   837  	getColumns := func(sql string) []*ColDef {
   838  		logicPlan, err := runOneStmt(mock, t, sql)
   839  		if err != nil {
   840  			t.Fatalf("sql %s build plan error:%+v", sql, err)
   841  		}
   842  		return GetResultColumnsFromPlan(logicPlan)
   843  	}
   844  
   845  	returnNilSQL := []string{
   846  		"begin",
   847  		"commit",
   848  		"rollback",
   849  		"INSERT NATION VALUES (1, 'NAME1',21, 'COMMENT1'), (2, 'NAME2', 22, 'COMMENT2')",
   850  		// "UPDATE NATION SET N_NAME ='U1', N_REGIONKEY=2",
   851  		// "DELETE FROM NATION",
   852  		"create database db_name",
   853  		"drop database tpch",
   854  		"create table tbl_name (b int unsigned, c char(20))",
   855  		"drop table nation",
   856  	}
   857  	for _, sql := range returnNilSQL {
   858  		columns := getColumns(sql)
   859  		if columns != nil {
   860  			t.Fatalf("sql:%+v, return columns should be nil", sql)
   861  		}
   862  	}
   863  
   864  	returnColumnsSQL := map[string]string{
   865  		"SELECT N_NAME, N_REGIONKEY a FROM NATION WHERE N_REGIONKEY > 0 ORDER BY a DESC":            "N_NAME,a",
   866  		"select n_nationkey, sum(n_regionkey) from (select * from nation) sub group by n_nationkey": "n_nationkey,sum(n_regionkey)",
   867  		"show variables":            "Variable_name,Value",
   868  		"show create database tpch": "Database,Create Database",
   869  		"show create table nation":  "Table,Create Table",
   870  		"show databases":            "Database",
   871  		"show tables":               "Tables_in_tpch",
   872  		"show columns from nation":  "Field,Type,Null,Key,Default,Extra,Comment",
   873  	}
   874  	for sql, colsStr := range returnColumnsSQL {
   875  		cols := strings.Split(colsStr, ",")
   876  		columns := getColumns(sql)
   877  		if len(columns) != len(cols) {
   878  			t.Fatalf("sql:%+v, return columns should be [%s]", sql, colsStr)
   879  		}
   880  		for idx, col := range cols {
   881  			// now ast always change col_name to lower string. will be fixed soon
   882  			if !strings.EqualFold(columns[idx].Name, col) {
   883  				t.Fatalf("sql:%+v, return columns should be [%s]", sql, colsStr)
   884  			}
   885  		}
   886  	}
   887  }
   888  
   889  func TestBuildUnnest(t *testing.T) {
   890  	mock := NewMockOptimizer(false)
   891  	sqls := []string{
   892  		`select * from unnest('{"a":1}') as f`,
   893  		`select * from unnest('{"a":1}', '') as f`,
   894  		`select * from unnest('{"a":1}', '$', true) as f`,
   895  	}
   896  	runTestShouldPass(mock, t, sqls, false, false)
   897  	errSqls := []string{
   898  		`select * from unnest(t.t1.a)`,
   899  		`select * from unnest(t.a, "$.b")`,
   900  		`select * from unnest(t.a, "$.b", true)`,
   901  		`select * from unnest(t.a) as f`,
   902  		`select * from unnest(t.a, "$.b") as f`,
   903  		`select * from unnest(t.a, "$.b", true) as f`,
   904  		`select * from unnest('{"a":1}')`,
   905  		`select * from unnest('{"a":1}', "$")`,
   906  		`select * from unnest('{"a":1}', "", true)`,
   907  	}
   908  	runTestShouldError(mock, t, errSqls)
   909  }
   910  
   911  func TestVisitRule(t *testing.T) {
   912  	sql := "select * from nation where n_nationkey > ? or n_nationkey=@int_var or abs(-1) > 1"
   913  	mock := NewMockOptimizer(false)
   914  	ctx := context.TODO()
   915  	plan, err := runOneStmt(mock, t, sql)
   916  	if err != nil {
   917  		t.Fatalf("should not error, sql=%s", sql)
   918  	}
   919  	getParamRule := NewGetParamRule()
   920  	vp := NewVisitPlan(plan, []VisitPlanRule{getParamRule})
   921  	err = vp.Visit(context.TODO())
   922  	if err != nil {
   923  		t.Fatalf("should not error, sql=%s", sql)
   924  	}
   925  	getParamRule.SetParamOrder()
   926  	args := getParamRule.params
   927  
   928  	resetParamOrderRule := NewResetParamOrderRule(args)
   929  	vp = NewVisitPlan(plan, []VisitPlanRule{resetParamOrderRule})
   930  	err = vp.Visit(ctx)
   931  	if err != nil {
   932  		t.Fatalf("should not error, sql=%s", sql)
   933  	}
   934  
   935  	params := []*Expr{
   936  		makePlan2Int64ConstExprWithType(10),
   937  	}
   938  	resetParamRule := NewResetParamRefRule(ctx, params)
   939  	resetVarRule := NewResetVarRefRule(&mock.ctxt, &process.Process{})
   940  	constantFoldRule := NewConstantFoldRule(&mock.ctxt)
   941  	vp = NewVisitPlan(plan, []VisitPlanRule{resetParamRule, resetVarRule, constantFoldRule})
   942  	err = vp.Visit(ctx)
   943  	if err != nil {
   944  		t.Fatalf("should not error, sql=%s", sql)
   945  	}
   946  }
   947  
   948  func getJSON(v any, t *testing.T) []byte {
   949  	b, err := json.Marshal(v)
   950  	if err != nil {
   951  		t.Logf("%+v", v)
   952  	}
   953  	var out bytes.Buffer
   954  	err = json.Indent(&out, b, "", "  ")
   955  	if err != nil {
   956  		t.Logf("%+v", v)
   957  	}
   958  	return out.Bytes()
   959  }
   960  
   961  func testDeepCopy(logicPlan *Plan) {
   962  	switch logicPlan.Plan.(type) {
   963  	case *plan.Plan_Query:
   964  		_ = DeepCopyPlan(logicPlan)
   965  	case *plan.Plan_Ddl:
   966  		_ = DeepCopyPlan(logicPlan)
   967  	case *plan.Plan_Dcl:
   968  	}
   969  }
   970  
   971  func outPutPlan(logicPlan *Plan, toFile bool, t *testing.T) {
   972  	var json []byte
   973  	switch logicPlan.Plan.(type) {
   974  	case *plan.Plan_Query:
   975  		json = getJSON(logicPlan.GetQuery(), t)
   976  	case *plan.Plan_Tcl:
   977  		json = getJSON(logicPlan.GetTcl(), t)
   978  	case *plan.Plan_Ddl:
   979  		json = getJSON(logicPlan.GetDdl(), t)
   980  	case *plan.Plan_Dcl:
   981  		json = getJSON(logicPlan.GetDcl(), t)
   982  	}
   983  	if toFile {
   984  		err := os.WriteFile("/tmp/mo_plan_test.json", json, 0777)
   985  		if err != nil {
   986  			t.Logf("%+v", err)
   987  		}
   988  	} else {
   989  		t.Logf(string(json))
   990  	}
   991  }
   992  
   993  func runOneStmt(opt Optimizer, t *testing.T, sql string) (*Plan, error) {
   994  	stmts, err := mysql.Parse(opt.CurrentContext().GetContext(), sql)
   995  	if err != nil {
   996  		t.Fatalf("%+v", err)
   997  	}
   998  	// this sql always return one stmt
   999  	ctx := opt.CurrentContext()
  1000  	return BuildPlan(ctx, stmts[0])
  1001  }
  1002  
  1003  func runTestShouldPass(opt Optimizer, t *testing.T, sqls []string, printJSON bool, toFile bool) {
  1004  	for _, sql := range sqls {
  1005  		logicPlan, err := runOneStmt(opt, t, sql)
  1006  		if err != nil {
  1007  			t.Fatalf("%+v, sql=%v", err, sql)
  1008  		}
  1009  		testDeepCopy(logicPlan)
  1010  		if printJSON {
  1011  			outPutPlan(logicPlan, toFile, t)
  1012  		}
  1013  	}
  1014  }
  1015  
  1016  func runTestShouldError(opt Optimizer, t *testing.T, sqls []string) {
  1017  	for _, sql := range sqls {
  1018  		_, err := runOneStmt(opt, t, sql)
  1019  		if err == nil {
  1020  			t.Fatalf("should error, but pass: %v", sql)
  1021  		}
  1022  	}
  1023  }
  1024  
  1025  func Test_mergeContexts(t *testing.T) {
  1026  	b1 := NewBinding(0, 1, "a", nil, nil, false)
  1027  	bc1 := NewBindContext(nil, nil)
  1028  	bc1.bindings = append(bc1.bindings, b1)
  1029  
  1030  	b2 := NewBinding(1, 2, "a", nil, nil, false)
  1031  	bc2 := NewBindContext(nil, nil)
  1032  	bc2.bindings = append(bc2.bindings, b2)
  1033  
  1034  	bc := NewBindContext(nil, nil)
  1035  
  1036  	//a merge a
  1037  	err := bc.mergeContexts(context.Background(), bc1, bc2)
  1038  	assert.Error(t, err)
  1039  	assert.EqualError(t, err, "invalid input: table 'a' specified more than once")
  1040  
  1041  	//a merge b
  1042  	b3 := NewBinding(2, 3, "b", nil, nil, false)
  1043  	bc3 := NewBindContext(nil, nil)
  1044  	bc3.bindings = append(bc3.bindings, b3)
  1045  
  1046  	err = bc.mergeContexts(context.Background(), bc1, bc3)
  1047  	assert.NoError(t, err)
  1048  
  1049  	// a merge a, ctx is  nil
  1050  	var ctx context.Context
  1051  	err = bc.mergeContexts(ctx, bc1, bc2)
  1052  	assert.Error(t, err)
  1053  	assert.EqualError(t, err, "invalid input: table 'a' specified more than once")
  1054  }