github.com/matrixorigin/matrixone@v1.2.0/pkg/sql/plan/explain/column_prune_test.go (about)

     1  // Copyright 2021 - 2022 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package explain
    16  
    17  import (
    18  	"context"
    19  	"strings"
    20  	"testing"
    21  
    22  	"github.com/matrixorigin/matrixone/pkg/common/moerr"
    23  	"github.com/matrixorigin/matrixone/pkg/pb/plan"
    24  	"github.com/matrixorigin/matrixone/pkg/sql/parsers/dialect/mysql"
    25  	plan2 "github.com/matrixorigin/matrixone/pkg/sql/plan"
    26  	"github.com/stretchr/testify/require"
    27  )
    28  
    29  func TestSingleTableQueryPrune(t *testing.T) {
    30  	cases := []struct {
    31  		name         string
    32  		sql          string
    33  		wantTableCol []Entry[string, []string]
    34  	}{
    35  		{
    36  			name: "Test01",
    37  			sql:  "SELECT N_NAME, N_REGIONKEY FROM NATION WHERE N_REGIONKEY > 0 AND N_NAME LIKE '%AA' ORDER BY N_NAME DESC, N_REGIONKEY LIMIT 10, 20",
    38  			wantTableCol: []Entry[string, []string]{
    39  				{
    40  					tableName: "nation",
    41  					colNames:  []string{"n_name", "n_regionkey"},
    42  				},
    43  			},
    44  		},
    45  
    46  		{
    47  			name: "Test02",
    48  			sql:  "SELECT N_NAME, N_REGIONKEY, 23 as a FROM NATION",
    49  			wantTableCol: []Entry[string, []string]{
    50  				{
    51  					tableName: "nation",
    52  					colNames:  []string{"n_name", "n_regionkey"},
    53  				},
    54  			},
    55  		},
    56  
    57  		{
    58  			name: "Test03",
    59  			sql:  "SELECT N_NAME, N_REGIONKEY a FROM NATION WHERE N_REGIONKEY > 0 ORDER BY a DESC",
    60  			wantTableCol: []Entry[string, []string]{
    61  				{
    62  					tableName: "nation",
    63  					colNames:  []string{"n_name", "n_regionkey"},
    64  				},
    65  			},
    66  		},
    67  
    68  		{
    69  			name: "Test04",
    70  			sql:  "SELECT NATION.N_NAME FROM NATION",
    71  			wantTableCol: []Entry[string, []string]{
    72  				{
    73  					tableName: "nation",
    74  					colNames:  []string{"n_name"},
    75  				},
    76  			},
    77  		},
    78  
    79  		{
    80  			name: "Test05",
    81  			sql:  "SELECT a.* FROM NATION a",
    82  			wantTableCol: []Entry[string, []string]{
    83  				{
    84  					tableName: "nation",
    85  					colNames:  []string{"n_nationkey", "n_name", "n_regionkey", "n_comment"},
    86  				},
    87  			},
    88  		},
    89  
    90  		{
    91  			name: "Test06",
    92  			sql:  "SELECT count(*) FROM NATION",
    93  			wantTableCol: []Entry[string, []string]{
    94  				{
    95  					tableName: "nation",
    96  					colNames:  []string{"n_nationkey"},
    97  				},
    98  			},
    99  		},
   100  
   101  		{
   102  			name: "Test07",
   103  			sql:  "SELECT count(*) FROM NATION group by N_NAME",
   104  			wantTableCol: []Entry[string, []string]{
   105  				{
   106  					tableName: "nation",
   107  					colNames:  []string{"n_name"},
   108  				},
   109  			},
   110  		},
   111  
   112  		{
   113  			name: "Test08",
   114  			sql:  "SELECT N_NAME, MAX(N_REGIONKEY) FROM NATION GROUP BY N_NAME HAVING MAX(N_REGIONKEY) > 10",
   115  			wantTableCol: []Entry[string, []string]{
   116  				{
   117  					tableName: "nation",
   118  					colNames:  []string{"n_name", "n_regionkey"},
   119  				},
   120  			},
   121  		},
   122  
   123  		{
   124  			name: "Test09",
   125  			sql:  "SELECT DISTINCT N_NAME FROM NATION limit 10",
   126  			wantTableCol: []Entry[string, []string]{
   127  				{
   128  					tableName: "nation",
   129  					colNames:  []string{"n_name"},
   130  				},
   131  			},
   132  		},
   133  
   134  		{
   135  			name: "Test10",
   136  			sql:  "SELECT DISTINCT N_NAME FROM NATION",
   137  			wantTableCol: []Entry[string, []string]{
   138  				{
   139  					tableName: "nation",
   140  					colNames:  []string{"n_name"},
   141  				},
   142  			},
   143  		},
   144  
   145  		{
   146  			name: "Test11",
   147  			sql:  "SELECT N_REGIONKEY + 2 as a, N_REGIONKEY/2, N_REGIONKEY* N_NATIONKEY, N_REGIONKEY % N_NATIONKEY, N_REGIONKEY - N_NATIONKEY FROM NATION WHERE -N_NATIONKEY < -20",
   148  			wantTableCol: []Entry[string, []string]{
   149  				{
   150  					tableName: "nation",
   151  					colNames:  []string{"n_nationkey", "n_regionkey"},
   152  				},
   153  			},
   154  		},
   155  
   156  		{
   157  			name: "Test12",
   158  			sql:  "SELECT N_REGIONKEY FROM NATION where N_REGIONKEY >= N_NATIONKEY or (N_NAME like '%ddd' and N_REGIONKEY >0.5)",
   159  			wantTableCol: []Entry[string, []string]{
   160  				{
   161  					tableName: "nation",
   162  					colNames:  []string{"n_nationkey", "n_name", "n_regionkey"},
   163  				},
   164  			},
   165  		},
   166  
   167  		{
   168  			name: "Test13",
   169  			sql:  "SELECT N_REGIONKEY FROM NATION where N_REGIONKEY between 2 and 2 OR N_NATIONKEY not between 3 and 10",
   170  			wantTableCol: []Entry[string, []string]{
   171  				{
   172  					tableName: "nation",
   173  					colNames:  []string{"n_nationkey", "n_regionkey"},
   174  				},
   175  			},
   176  		},
   177  	}
   178  
   179  	for _, c := range cases {
   180  		t.Run(c.name, func(t *testing.T) {
   181  			mock := plan2.NewMockOptimizer(false)
   182  			logicPlan, err := buildOneStmt(mock, t, c.sql)
   183  			if err != nil {
   184  				t.Fatalf("%+v", err)
   185  			}
   186  			columns, err := getPrunedTableColumns(logicPlan)
   187  			if err != nil {
   188  				t.Fatalf("%+v", err)
   189  			}
   190  			require.Equal(t, c.wantTableCol, columns)
   191  		})
   192  	}
   193  }
   194  
   195  func TestJoinQueryPrune(t *testing.T) {
   196  	cases := []struct {
   197  		name         string
   198  		sql          string
   199  		wantTableCol []Entry[string, []string]
   200  	}{
   201  		{
   202  			name: "Test01",
   203  			sql:  "SELECT NATION.N_NAME, REGION.R_NAME FROM NATION join REGION on NATION.N_REGIONKEY = REGION.R_REGIONKEY WHERE NATION.N_REGIONKEY > 10 AND NATION.N_NAME > REGION.R_NAME",
   204  			wantTableCol: []Entry[string, []string]{
   205  				{
   206  					tableName: "nation",
   207  					colNames:  []string{"n_name", "n_regionkey"},
   208  				},
   209  				{
   210  					tableName: "region",
   211  					colNames:  []string{"r_regionkey", "r_name"},
   212  				},
   213  			},
   214  		},
   215  		{
   216  			name: "Test02",
   217  			sql:  "SELECT NATION.N_NAME, REGION.R_NAME FROM NATION left join REGION on NATION.N_REGIONKEY = REGION.R_REGIONKEY WHERE NATION.N_REGIONKEY > 10 AND NATION.N_NAME > REGION.R_NAME",
   218  			wantTableCol: []Entry[string, []string]{
   219  				{
   220  					tableName: "nation",
   221  					colNames:  []string{"n_name", "n_regionkey"},
   222  				},
   223  				{
   224  					tableName: "region",
   225  					colNames:  []string{"r_regionkey", "r_name"},
   226  				},
   227  			},
   228  		},
   229  		{
   230  			name: "Test03",
   231  			sql:  "SELECT N_NAME, N_REGIONKEY FROM NATION join REGION on NATION.N_REGIONKEY = REGION.R_REGIONKEY WHERE NATION.N_REGIONKEY > 0",
   232  			wantTableCol: []Entry[string, []string]{
   233  				{
   234  					tableName: "nation",
   235  					colNames:  []string{"n_name", "n_regionkey"},
   236  				},
   237  				{
   238  					tableName: "region",
   239  					colNames:  []string{"r_regionkey"},
   240  				},
   241  			},
   242  		},
   243  		{
   244  			name: "Test04",
   245  			sql:  "SELECT N_NAME, NATION2.R_REGIONKEY FROM NATION2 join REGION using(R_REGIONKEY) WHERE NATION2.R_REGIONKEY > 0",
   246  			wantTableCol: []Entry[string, []string]{
   247  				{
   248  					tableName: "nation2",
   249  					colNames:  []string{"n_name", "r_regionkey"},
   250  				},
   251  				{
   252  					tableName: "region",
   253  					colNames:  []string{"r_regionkey"},
   254  				},
   255  			},
   256  		},
   257  		{
   258  			name: "Test05",
   259  			sql:  "SELECT N_NAME, NATION2.R_REGIONKEY FROM NATION2 NATURAL JOIN REGION WHERE NATION2.R_REGIONKEY > 0",
   260  			wantTableCol: []Entry[string, []string]{
   261  				{
   262  					tableName: "nation2",
   263  					colNames:  []string{"n_name", "r_regionkey"},
   264  				},
   265  				{
   266  					tableName: "region",
   267  					colNames:  []string{"r_regionkey"},
   268  				},
   269  			},
   270  		},
   271  		{
   272  			name: "Test06",
   273  			sql:  "SELECT N_NAME FROM NATION NATURAL JOIN REGION",
   274  			wantTableCol: []Entry[string, []string]{
   275  				{
   276  					tableName: "nation",
   277  					colNames:  []string{"n_name"},
   278  				},
   279  				{
   280  					tableName: "region",
   281  					colNames:  []string{"r_regionkey"},
   282  				},
   283  			},
   284  		},
   285  		{
   286  			name: "Test07",
   287  			sql:  "SELECT N_NAME,N_REGIONKEY FROM NATION a join REGION b on a.N_REGIONKEY = b.R_REGIONKEY WHERE a.N_REGIONKEY > 0",
   288  			wantTableCol: []Entry[string, []string]{
   289  				{
   290  					tableName: "nation",
   291  					colNames:  []string{"n_name", "n_regionkey"},
   292  				},
   293  				{
   294  					tableName: "region",
   295  					colNames:  []string{"r_regionkey"},
   296  				},
   297  			},
   298  		},
   299  		{
   300  			name: "Test08",
   301  			sql:  "SELECT l.L_ORDERKEY a FROM CUSTOMER c, ORDERS o, LINEITEM l WHERE c.C_CUSTKEY = o.O_CUSTKEY and l.L_ORDERKEY = o.O_ORDERKEY and o.O_ORDERKEY < 10",
   302  			wantTableCol: []Entry[string, []string]{
   303  				{
   304  					tableName: "customer",
   305  					colNames:  []string{"c_custkey"},
   306  				},
   307  				{
   308  					tableName: "orders",
   309  					colNames:  []string{"o_orderkey", "o_custkey"},
   310  				},
   311  				{
   312  					tableName: "lineitem",
   313  					colNames:  []string{"l_orderkey"},
   314  				},
   315  			},
   316  		},
   317  		{
   318  			name: "Test09",
   319  			sql:  "SELECT c.* FROM CUSTOMER c, ORDERS o, LINEITEM l WHERE c.C_CUSTKEY = o.O_CUSTKEY and l.L_ORDERKEY = o.O_ORDERKEY",
   320  			wantTableCol: []Entry[string, []string]{
   321  				{
   322  					tableName: "customer",
   323  					colNames:  []string{"c_custkey", "c_name", "c_address", "c_nationkey", "c_phone", "c_acctbal", "c_mktsegment", "c_comment"},
   324  				},
   325  				{
   326  					tableName: "orders",
   327  					colNames:  []string{"o_orderkey", "o_custkey"},
   328  				},
   329  				{
   330  					tableName: "lineitem",
   331  					colNames:  []string{"l_orderkey"},
   332  				},
   333  			},
   334  		},
   335  		{
   336  			name: "Test10",
   337  			sql:  "SELECT * FROM CUSTOMER c, ORDERS o, LINEITEM l WHERE c.C_CUSTKEY = o.O_CUSTKEY and l.L_ORDERKEY = o.O_ORDERKEY",
   338  			wantTableCol: []Entry[string, []string]{
   339  				{
   340  					tableName: "customer",
   341  					colNames:  []string{"c_custkey", "c_name", "c_address", "c_nationkey", "c_phone", "c_acctbal", "c_mktsegment", "c_comment"},
   342  				},
   343  				{
   344  					tableName: "orders",
   345  					colNames:  []string{"o_orderkey", "o_custkey", "o_orderstatus", "o_totalprice", "o_orderdate", "o_orderpriority", "o_clerk", "o_shippriority", "o_comment"},
   346  				},
   347  				{
   348  					tableName: "lineitem",
   349  					colNames:  []string{"l_orderkey", "l_partkey", "l_suppkey", "l_linenumber", "l_quantity", "l_extendedprice", "l_discount", "l_tax", "l_returnflag", "l_linestatus", "l_shipdate", "l_commitdate", "l_receiptdate", "l_shipinstruct", "l_shipmode", "l_comment"},
   350  				},
   351  			},
   352  		},
   353  		{
   354  			name: "Test11",
   355  			sql:  "SELECT a.* FROM NATION a join REGION b on a.N_REGIONKEY = b.R_REGIONKEY WHERE a.N_REGIONKEY > 0",
   356  			wantTableCol: []Entry[string, []string]{
   357  				{
   358  					tableName: "nation",
   359  					colNames:  []string{"n_nationkey", "n_name", "n_regionkey", "n_comment"},
   360  				},
   361  				{
   362  					tableName: "region",
   363  					colNames:  []string{"r_regionkey"},
   364  				},
   365  			},
   366  		},
   367  		{
   368  			name: "Test12",
   369  			sql:  "SELECT a.* FROM NATION a join REGION b on a.N_REGIONKEY = b.R_REGIONKEY WHERE a.N_REGIONKEY > 0",
   370  			wantTableCol: []Entry[string, []string]{
   371  				{
   372  					tableName: "nation",
   373  					colNames:  []string{"n_nationkey", "n_name", "n_regionkey", "n_comment"},
   374  				},
   375  				{
   376  					tableName: "region",
   377  					colNames:  []string{"r_regionkey"},
   378  				},
   379  			},
   380  		},
   381  	}
   382  
   383  	for _, c := range cases {
   384  		t.Run(c.name, func(t *testing.T) {
   385  			mock := plan2.NewMockOptimizer(false)
   386  			logicPlan, err := buildOneStmt(mock, t, c.sql)
   387  			if err != nil {
   388  				t.Fatalf("%+v", err)
   389  			}
   390  			columns, err := getPrunedTableColumns(logicPlan)
   391  			if err != nil {
   392  				t.Fatalf("%+v", err)
   393  			}
   394  			require.Equal(t, c.wantTableCol, columns)
   395  		})
   396  	}
   397  }
   398  
   399  func TestNestedQueryPrune(t *testing.T) {
   400  
   401  	cases := []struct {
   402  		name         string
   403  		sql          string
   404  		wantTableCol []Entry[string, []string]
   405  	}{
   406  		{
   407  			name: "Test01",
   408  			sql:  "SELECT * FROM NATION where N_REGIONKEY > (select max(R_REGIONKEY) from REGION)",
   409  			wantTableCol: []Entry[string, []string]{
   410  				{
   411  					tableName: "nation",
   412  					colNames:  []string{"n_nationkey", "n_name", "n_regionkey", "n_comment"},
   413  				},
   414  				{
   415  					tableName: "region",
   416  					colNames:  []string{"r_regionkey"},
   417  				},
   418  			},
   419  		},
   420  		{
   421  			name: "Test02",
   422  			sql:  "SELECT * FROM NATION where N_REGIONKEY > (select max(R_REGIONKEY) from REGION where R_REGIONKEY = N_REGIONKEY)",
   423  			wantTableCol: []Entry[string, []string]{
   424  				{
   425  					tableName: "nation",
   426  					colNames:  []string{"n_nationkey", "n_name", "n_regionkey", "n_comment"},
   427  				},
   428  				{
   429  					tableName: "region",
   430  					colNames:  []string{"r_regionkey"},
   431  				},
   432  			},
   433  		},
   434  		{
   435  			name: "Test03",
   436  			sql:  "select sum(l_extendedprice) / 7.0 as avg_yearly from lineitem, part where p_partkey = l_partkey and p_brand = 'Brand#54' and p_container = 'LG BAG' and l_quantity < (select 0.2 * avg(l_quantity) from lineitem where l_partkey = p_partkey)",
   437  			wantTableCol: []Entry[string, []string]{
   438  				{
   439  					tableName: "lineitem",
   440  					colNames:  []string{"l_partkey", "l_quantity", "l_extendedprice"},
   441  				},
   442  				{
   443  					tableName: "part",
   444  					colNames:  []string{"p_partkey", "p_brand", "p_container"},
   445  				},
   446  				{
   447  					tableName: "lineitem",
   448  					colNames:  []string{"l_partkey", "l_quantity"},
   449  				},
   450  			},
   451  		},
   452  	}
   453  
   454  	for _, c := range cases {
   455  		t.Run(c.name, func(t *testing.T) {
   456  			mock := plan2.NewMockOptimizer(false)
   457  			logicPlan, err := buildOneStmt(mock, t, c.sql)
   458  			if err != nil {
   459  				t.Fatalf("%+v", err)
   460  			}
   461  			columns, err := getPrunedTableColumns(logicPlan)
   462  			if err != nil {
   463  				t.Fatalf("%+v", err)
   464  			}
   465  			require.Equal(t, c.wantTableCol, columns)
   466  		})
   467  	}
   468  
   469  }
   470  
   471  func TestDerivedTableQueryPrune(t *testing.T) {
   472  	cases := []struct {
   473  		name         string
   474  		sql          string
   475  		wantTableCol []Entry[string, []string]
   476  	}{
   477  		{
   478  			name: "Test01",
   479  			sql:  "select c_custkey from (select c_custkey from CUSTOMER group by c_custkey ) a",
   480  			wantTableCol: []Entry[string, []string]{
   481  				{
   482  					tableName: "customer",
   483  					colNames:  []string{"c_custkey"},
   484  				},
   485  			},
   486  		},
   487  		{
   488  			name: "Test02",
   489  			sql:  "select c_custkey from (select c_custkey, count(C_NATIONKEY) ff from CUSTOMER group by c_custkey ) a where ff > 0 order by c_custkey",
   490  			wantTableCol: []Entry[string, []string]{
   491  				{
   492  					tableName: "customer",
   493  					colNames:  []string{"c_custkey", "c_nationkey"},
   494  				},
   495  			},
   496  		},
   497  		{
   498  			name: "Test03",
   499  			sql:  "select c_custkey from (select c_custkey, count(C_NATIONKEY) ff from CUSTOMER group by c_custkey ) a join NATION b on a.c_custkey = b.N_REGIONKEY where b.N_NATIONKEY > 10",
   500  			wantTableCol: []Entry[string, []string]{
   501  				{
   502  					tableName: "customer",
   503  					colNames:  []string{"c_custkey", "c_nationkey"},
   504  				},
   505  				{
   506  					tableName: "nation",
   507  					colNames:  []string{"n_nationkey", "n_regionkey"},
   508  				},
   509  			},
   510  		},
   511  		{
   512  			name: "Test04",
   513  			sql:  "select a.* from (select c_custkey, count(C_NATIONKEY) ff from CUSTOMER group by c_custkey ) a join NATION b on a.c_custkey = b.N_REGIONKEY where b.N_NATIONKEY > 10",
   514  			wantTableCol: []Entry[string, []string]{
   515  				{
   516  					tableName: "customer",
   517  					colNames:  []string{"c_custkey", "c_nationkey"},
   518  				},
   519  				{
   520  					tableName: "nation",
   521  					colNames:  []string{"n_nationkey", "n_regionkey"},
   522  				},
   523  			},
   524  		},
   525  		{
   526  			name: "Test05",
   527  			sql:  "select * from (select c_custkey, count(C_NATIONKEY) ff from CUSTOMER group by c_custkey ) a join NATION b on a.c_custkey = b.N_REGIONKEY where b.N_NATIONKEY > 10",
   528  			wantTableCol: []Entry[string, []string]{
   529  				{
   530  					tableName: "customer",
   531  					colNames:  []string{"c_custkey", "c_nationkey"},
   532  				},
   533  				{
   534  					tableName: "nation",
   535  					colNames:  []string{"n_nationkey", "n_name", "n_regionkey", "n_comment"},
   536  				},
   537  			},
   538  		},
   539  	}
   540  
   541  	for _, c := range cases {
   542  		t.Run(c.name, func(t *testing.T) {
   543  			mock := plan2.NewMockOptimizer(false)
   544  			logicPlan, err := buildOneStmt(mock, t, c.sql)
   545  			if err != nil {
   546  				t.Fatalf("%+v", err)
   547  			}
   548  			columns, err := getPrunedTableColumns(logicPlan)
   549  			if err != nil {
   550  				t.Fatalf("%+v", err)
   551  			}
   552  			require.Equal(t, c.wantTableCol, columns)
   553  		})
   554  	}
   555  
   556  }
   557  
   558  func buildOneStmt(opt plan2.Optimizer, t *testing.T, sql string) (*plan.Plan, error) {
   559  	stmts, err := mysql.Parse(opt.CurrentContext().GetContext(), sql, 1, 0)
   560  	if err != nil {
   561  		t.Fatalf("%+v", err)
   562  	}
   563  	// this sql always return one stmt
   564  	ctx := opt.CurrentContext()
   565  	return plan2.BuildPlan(ctx, stmts[0], false)
   566  }
   567  
   568  type Entry[K any, V any] struct {
   569  	tableName K
   570  	colNames  V
   571  }
   572  
   573  // Get the remaining columns after the table is cropped
   574  func getPrunedTableColumns(logicPlan *plan.Plan) ([]Entry[string, []string], error) {
   575  	query := logicPlan.GetQuery()
   576  	if query.StmtType != plan.Query_SELECT {
   577  		return nil, moerr.NewNotSupported(context.TODO(), "SQL is not a DQL")
   578  	}
   579  
   580  	res := make([]Entry[string, []string], 0)
   581  	for _, node := range query.Nodes {
   582  		if node.NodeType == plan.Node_TABLE_SCAN {
   583  			tableDef := node.TableDef
   584  			tableName := strings.ToLower(tableDef.Name)
   585  
   586  			columns := make([]string, 0)
   587  			for _, col := range tableDef.Cols {
   588  				columns = append(columns, strings.ToLower(col.Name))
   589  			}
   590  			entry := Entry[string, []string]{
   591  				tableName: tableName,
   592  				colNames:  columns,
   593  			}
   594  			res = append(res, entry)
   595  		}
   596  	}
   597  	return res, nil
   598  }